Reduce Conv2D.__call__ overhead by 45%, Conv3D.__call__ overhead by 30%, and
Conv1D.__call__ overhead by 15%.

Removes the indirection caused by using the nn_ops.Convolution class rather
than calling nn_ops.convolution directly.

PiperOrigin-RevId: 315920158
Change-Id: Ia5db49a36870b424b70d9d04962b84697d89bf55
parent 52806d3849
commit 17b63987e5
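The pattern repeated throughout the diff: build() used to construct an
nn_ops.Convolution object, which re-derived padding, data format, and spatial
metadata and added a layer of dispatch on every __call__. It now binds the
static arguments once with functools.partial over the functional op. A minimal
sketch of that refactor, using a stand-in conv function rather than the real
nn_ops.convolution_v2 signature:

    import functools

    def conv_nd(inputs, filters, strides=1, padding='VALID', name=None):
      # Stand-in for the functional op (nn_ops.convolution_v2 in the real code).
      print('conv(%s, %s, strides=%s, padding=%s, name=%s)'
            % (inputs, filters, strides, padding, name))

    # Before: a dispatcher object, rebuilt whenever the input shape changed.
    class ConvolutionDispatcher(object):

      def __init__(self, input_shape, strides, padding):
        # Shape bookkeeping lived here and had to stay in sync with call().
        self.strides, self.padding = strides, padding

      def __call__(self, inputs, filters):
        return conv_nd(inputs, filters, self.strides, self.padding)

    ConvolutionDispatcher((28, 28, 2), 2, 'SAME')('x', 'kernel')

    # After: all static arguments bound once in build(); each call is a plain
    # function call with no per-call shape checks or extra dispatch.
    convolution_op = functools.partial(
        conv_nd, strides=2, padding='SAME', name='conv1d')
    convolution_op('x', 'kernel')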
@@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):
     xla_names = _GetNodeNames(use_xla=True)
     no_xla_names = _GetNodeNames(use_xla=False)
-    self.assertListEqual(
-        xla_names,
-        no_xla_names,
-    )
+
+    # CPU path creates some additional nodes to handle dilations.
+    # TODO(b/138804006): Remove this when CPU & GPU support dilations.
+    filtered_no_xla_names = []
+    for name in no_xla_names:
+      if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
+        continue
+      else:
+        filtered_no_xla_names.append(name)
+
+    self.assertListEqual(xla_names, filtered_no_xla_names)
 
   def testConv1DNodeNameMatch(self):
     input_sizes = [8, 16, 3]
@@ -19,6 +19,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import six
+
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
@@ -125,6 +128,7 @@ class Conv(Layer):
                bias_constraint=None,
                trainable=True,
                name=None,
+               conv_op=None,
                **kwargs):
     super(Conv, self).__init__(
         trainable=trainable,
@@ -132,30 +136,22 @@ class Conv(Layer):
         activity_regularizer=regularizers.get(activity_regularizer),
         **kwargs)
     self.rank = rank
-    if filters is not None and not isinstance(filters, int):
+
+    if isinstance(filters, float):
       filters = int(filters)
     self.filters = filters
     self.groups = groups or 1
-    if filters is not None and filters % self.groups != 0:
-      raise ValueError(
-          'The number of filters must be evenly divisible by the number of '
-          'groups. Received: groups={}, filters={}'.format(groups, filters))
     self.kernel_size = conv_utils.normalize_tuple(
         kernel_size, rank, 'kernel_size')
-    if not all(self.kernel_size):
-      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
-                       'Received: %s' % (kernel_size,))
     self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
     self.padding = conv_utils.normalize_padding(padding)
-    if (self.padding == 'causal' and not isinstance(self,
-                                                    (Conv1D, SeparableConv1D))):
-      raise ValueError('Causal padding is only supported for `Conv1D`'
-                       'and ``SeparableConv1D`.')
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.dilation_rate = conv_utils.normalize_tuple(
         dilation_rate, rank, 'dilation_rate')
+
     self.activation = activations.get(activation)
     self.use_bias = use_bias
+
     self.kernel_initializer = initializers.get(kernel_initializer)
     self.bias_initializer = initializers.get(bias_initializer)
     self.kernel_regularizer = regularizers.get(kernel_regularizer)
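For reference, conv_utils.normalize_tuple only broadcasts a scalar to the
layer's rank. A pure-Python sketch of its behavior (the real helper lives in
tensorflow.python.keras.utils.conv_utils; this is an approximation of it):

    def normalize_tuple(value, n, name):
      # Broadcast an int to an n-tuple; otherwise validate the length.
      if isinstance(value, int):
        return (value,) * n
      value = tuple(value)
      if len(value) != n:
        raise ValueError('The `%s` argument must be a tuple of %d integers. '
                         'Received: %s' % (name, n, value))
      return value

    assert normalize_tuple(3, 2, 'kernel_size') == (3, 3)
    assert normalize_tuple((3, 5), 2, 'kernel_size') == (3, 5)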
@@ -164,6 +160,28 @@ class Conv(Layer):
     self.bias_constraint = constraints.get(bias_constraint)
     self.input_spec = InputSpec(min_ndim=self.rank + 2)
 
+    self._validate_init()
+    self._is_causal = self.padding == 'causal'
+    self._channels_first = self.data_format == 'channels_first'
+    self._tf_data_format = conv_utils.convert_data_format(
+        self.data_format, self.rank + 2)
+
+  def _validate_init(self):
+    if self.filters is not None and self.filters % self.groups != 0:
+      raise ValueError(
+          'The number of filters must be evenly divisible by the number of '
+          'groups. Received: groups={}, filters={}'.format(
+              self.groups, self.filters))
+
+    if not all(self.kernel_size):
+      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
+                       'Received: %s' % (self.kernel_size,))
+
+    if (self.padding == 'causal' and not isinstance(self,
+                                                    (Conv1D, SeparableConv1D))):
+      raise ValueError('Causal padding is only supported for `Conv1D`'
+                       'and `SeparableConv1D`.')
+
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     input_channel = self._get_input_channel(input_shape)
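The consolidated _validate_init keeps the construction-time error behavior
unchanged; for example (assuming a TF build with group support, as in this
diff), a filter count not divisible by groups still fails at init:

    import tensorflow as tf

    try:
      tf.keras.layers.Conv2D(filters=4, kernel_size=3, groups=3)
    except ValueError as e:
      print(e)  # ... filters must be evenly divisible ... groups=3, filters=4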
@@ -199,64 +217,53 @@ class Conv(Layer):
     self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                 axes={channel_axis: input_channel})
 
-    self._build_conv_op_data_shape = input_shape[-(self.rank + 1):]
-    self._build_input_channel = input_channel
-    self._padding_op = self._get_padding_op()
-    self._conv_op_data_format = conv_utils.convert_data_format(
-        self.data_format, self.rank + 2)
-    self._convolution_op = nn_ops.Convolution(
-        input_shape,
-        filter_shape=self.kernel.shape,
-        dilation_rate=self.dilation_rate,
-        strides=self.strides,
-        padding=self._padding_op,
-        data_format=self._conv_op_data_format,
-        num_spatial_dims=self.rank)
+    # Convert Keras formats to TF native formats.
+    if self.padding == 'causal':
+      tf_padding = 'VALID'  # Causal padding handled in `call`.
+    elif isinstance(self.padding, six.string_types):
+      tf_padding = self.padding.upper()
+    else:
+      tf_padding = self.padding
+    tf_dilations = self.dilation_rate
+    tf_strides = self.strides
+
+    tf_op_name = self.__class__.__name__
+    if tf_op_name == 'Conv1D':
+      tf_op_name = 'conv1d'  # Backwards compat.
+
+    self._convolution_op = functools.partial(
+        nn_ops.convolution_v2,
+        strides=tf_strides,
+        padding=tf_padding,
+        dilations=tf_dilations,
+        data_format=self._tf_data_format,
+        name=tf_op_name)
     self.built = True
 
   def call(self, inputs):
-    if self._recreate_conv_op(inputs):
-      self._convolution_op = nn_ops.Convolution(
-          inputs.shape,
-          filter_shape=self.kernel.shape,
-          dilation_rate=self.dilation_rate,
-          strides=self.strides,
-          padding=self._padding_op,
-          data_format=self._conv_op_data_format,
-          num_spatial_dims=self.rank)
-      self._build_conv_op_data_shape = inputs.shape[-(self.rank + 1):]
-
-    # Apply causal padding to inputs for Conv1D.
-    if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
+    if self._is_causal:  # Apply causal padding to inputs for Conv1D.
       inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
 
     outputs = self._convolution_op(inputs, self.kernel)
 
     if self.use_bias:
-      outputs_rank = outputs.shape.ndims
-      if self.data_format == 'channels_first':
-        if self.rank == 1:
-          # nn.bias_add does not accept a 1D input tensor.
-          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
-          outputs += bias
-        else:
-          if outputs_rank is not None and outputs_rank > 2 + self.rank:
-            # larger batch rank
-            outputs = nn_ops.squeeze_batch_dims(
-                outputs,
-                lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
-                inner_rank=self.rank + 1)
-          else:
-            outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
+      output_rank = outputs.shape.rank
+      if self.rank == 1 and self._channels_first:
+        # nn.bias_add does not accept a 1D input tensor.
+        bias = array_ops.reshape(self.bias, (1, self.filters, 1))
+        outputs += bias
       else:
-        if outputs_rank is not None and outputs_rank > 2 + self.rank:
-          # larger batch rank
+        # Handle multiple batch dimensions.
+        if output_rank is not None and output_rank > 2 + self.rank:
+
+          def _apply_fn(o):
+            return nn.bias_add(o, self.bias, data_format=self._tf_data_format)
+
           outputs = nn_ops.squeeze_batch_dims(
-              outputs,
-              lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
-              inner_rank=self.rank + 1)
+              outputs, _apply_fn, inner_rank=self.rank + 1)
         else:
-          outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
+          outputs = nn.bias_add(
+              outputs, self.bias, data_format=self._tf_data_format)
 
     if self.activation is not None:
       return self.activation(outputs)
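nn_ops.squeeze_batch_dims, used above for inputs with more than one batch
dimension, flattens the extra leading dims, applies the op, and restores them.
A rough NumPy sketch of the idea, not the real implementation:

    import numpy as np

    def squeeze_batch_dims(x, fn, inner_rank):
      # Collapse all leading dims beyond inner_rank, apply fn, then restore.
      batch_shape = x.shape[:-inner_rank]              # e.g. (5, 5)
      flat = x.reshape((-1,) + x.shape[-inner_rank:])  # (25, 26, 26, 2)
      out = fn(flat)
      return out.reshape(batch_shape + out.shape[1:])

    bias = np.ones(2)
    x = np.zeros((5, 5, 26, 26, 2))
    # inner_rank = self.rank + 1 (spatial dims plus channels) for a Conv2D.
    y = squeeze_batch_dims(x, lambda o: o + bias, inner_rank=3)
    print(y.shape)  # (5, 5, 26, 26, 2)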
@@ -286,6 +293,9 @@ class Conv(Layer):
         input_shape[:batch_rank] + [self.filters] +
         self._spatial_output_shape(input_shape[batch_rank + 1:]))
 
+  def _recreate_conv_op(self, inputs):  # pylint: disable=unused-argument
+    return False
+
   def get_config(self):
     config = {
         'filters':
@@ -359,27 +369,6 @@ class Conv(Layer):
     op_padding = op_padding.upper()
     return op_padding
 
-  def _recreate_conv_op(self, inputs):
-    """Recreate conv_op if necessary.
-
-    Check if the input_shape in call() is different from that in build().
-    If the most-specific input shape describing the build and call shapes is not
-    equal to the shape we currently built with, then we need to rebuild the
-    _convolution_op to avoid incorrect behavior.
-
-    Args:
-      inputs: The input data to call() method.
-
-    Returns:
-      `True` or `False` to indicate whether to recreate the conv_op.
-    """
-    call_data_shape = inputs.shape[-(self.rank + 1):]
-    # If the most specific compatible shape between _build_data_shape and
-    # call_data_shape is not _build_data_shape then we must re-build.
-    return (self._build_conv_op_data_shape
-            != self._build_conv_op_data_shape.most_specific_compatible_shape(
-                call_data_shape))
-
 
 @keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
 class Conv1D(Conv):
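The deleted helper leaned on TensorShape.most_specific_compatible_shape:
dimensions agree where both sides are equal, and any mismatch relaxes that
dimension to unknown. A pure-Python sketch of the check the old code performed
(real TensorShape dims may also be None):

    def most_specific_compatible_shape(a, b):
      # Keep a dimension only where both shapes agree; otherwise None.
      return tuple(x if x == y else None for x, y in zip(a, b))

    build_shape = (28, 28, 2)  # spatial dims + channels seen at build() time
    call_shape = (30, 30, 2)   # spatial dims changed at call() time
    # Old behavior: the result differs from build_shape, so the op was rebuilt.
    print(most_specific_compatible_shape(build_shape, call_shape))  # (None, None, 2)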
@@ -572,74 +561,60 @@ class Conv2D(Conv):
 
 
   Arguments:
-    filters: Integer, the dimensionality of the output space
-      (i.e. the number of output filters in the convolution).
-    kernel_size: An integer or tuple/list of 2 integers, specifying the
-      height and width of the 2D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 2 integers,
-      specifying the strides of the convolution along the height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
+    filters: Integer, the dimensionality of the output space (i.e. the number of
+      output filters in the convolution).
+    kernel_size: An integer or tuple/list of 2 integers, specifying the height
+      and width of the 2D convolution window. Can be a single integer to specify
+      the same value for all spatial dimensions.
+    strides: An integer or tuple/list of 2 integers, specifying the strides of
+      the convolution along the height and width. Can be a single integer to
+      specify the same value for all spatial dimensions. Specifying any stride
+      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
     padding: one of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string,
-      one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch_size, height, width, channels)` while `channels_first`
-      corresponds to inputs with shape
-      `(batch_size, channels, height, width)`.
-      It defaults to the `image_data_format` value found in your
-      Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be `channels_last`.
-    dilation_rate: an integer or tuple/list of 2 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs. `channels_last` corresponds
+      to inputs with shape `(batch_size, height, width, channels)` while
+      `channels_first` corresponds to inputs with shape `(batch_size, channels,
+      height, width)`. It defaults to the `image_data_format` value found in
+      your Keras config file at `~/.keras/keras.json`. If you never set it, then
+      it will be `channels_last`.
+    dilation_rate: an integer or tuple/list of 2 integers, specifying the
+      dilation rate to use for dilated convolution. Can be a single integer to
+      specify the same value for all spatial dimensions. Currently, specifying
+      any `dilation_rate` value != 1 is incompatible with specifying any stride
+      value != 1.
     groups: A positive integer specifying the number of groups in which the
-      input is split along the channel axis. Each group is convolved
-      separately with `filters / groups` filters. The output is the
-      concatenation of all the `groups` results along the channel axis.
-      Input channels and `filters` must both be divisible by `groups`.
-    activation: Activation function to use.
-      If you don't specify anything, no activation is applied (
-      see `keras.activations`).
+      input is split along the channel axis. Each group is convolved separately
+      with `filters / groups` filters. The output is the concatenation of all
+      the `groups` results along the channel axis. Input channels and `filters`
+      must both be divisible by `groups`.
+    activation: Activation function to use. If you don't specify anything, no
+      activation is applied (see `keras.activations`).
     use_bias: Boolean, whether the layer uses a bias vector.
-    kernel_initializer: Initializer for the `kernel` weights matrix (
-      see `keras.initializers`).
-    bias_initializer: Initializer for the bias vector (
-      see `keras.initializers`).
-    kernel_regularizer: Regularizer function applied to
-      the `kernel` weights matrix (see `keras.regularizers`).
-    bias_regularizer: Regularizer function applied to the bias vector (
-      see `keras.regularizers`).
-    activity_regularizer: Regularizer function applied to
-      the output of the layer (its "activation") (
-      see `keras.regularizers`).
-    kernel_constraint: Constraint function applied to the kernel matrix (
-      see `keras.constraints`).
-    bias_constraint: Constraint function applied to the bias vector (
-      see `keras.constraints`).
-
+    kernel_initializer: Initializer for the `kernel` weights matrix (see
+      `keras.initializers`).
+    bias_initializer: Initializer for the bias vector (see
+      `keras.initializers`).
+    kernel_regularizer: Regularizer function applied to the `kernel` weights
+      matrix (see `keras.regularizers`).
+    bias_regularizer: Regularizer function applied to the bias vector (see
+      `keras.regularizers`).
+    activity_regularizer: Regularizer function applied to the output of the
+      layer (its "activation") (see `keras.regularizers`).
+    kernel_constraint: Constraint function applied to the kernel matrix (see
+      `keras.constraints`).
+    bias_constraint: Constraint function applied to the bias vector (see
+      `keras.constraints`).
   Input shape:
-    4+D tensor with shape:
-    `batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
-    or 4+D tensor with shape:
-    `batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.
-
+    4+D tensor with shape: `batch_shape + (channels, rows, cols)` if
+      `data_format='channels_first'`
+    or 4+D tensor with shape: `batch_shape + (rows, cols, channels)` if
+      `data_format='channels_last'`.
   Output shape:
-    4+D tensor with shape:
-    `batch_shape + (filters, new_rows, new_cols)` if
-    `data_format='channels_first'` or 4+D tensor with shape:
-    `batch_shape + (new_rows, new_cols, filters)` if
-    `data_format='channels_last'`.
-
-    `rows` and `cols` values might have changed due to padding.
+    4+D tensor with shape: `batch_shape + (filters, new_rows, new_cols)` if
+      `data_format='channels_first'` or 4+D tensor with shape: `batch_shape +
+      (new_rows, new_cols, filters)` if `data_format='channels_last'`. `rows`
+      and `cols` values might have changed due to padding.
 
   Returns:
     A tensor of rank 4+ representing
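A short usage sketch of the shapes this docstring describes (standard Keras
API):

    import numpy as np
    import tensorflow as tf

    # channels_last input: batch_shape + (rows, cols, channels).
    x = np.random.rand(4, 28, 28, 3).astype('float32')
    y = tf.keras.layers.Conv2D(filters=2, kernel_size=3, activation='relu')(x)
    print(y.shape)  # (4, 26, 26, 2): rows/cols shrink under 'valid' padding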
@@ -727,79 +702,63 @@ class Conv3D(Conv):
   (4, 7, 26, 26, 26, 2)
 
   Arguments:
-    filters: Integer, the dimensionality of the output space
-      (i.e. the number of output filters in the convolution).
-    kernel_size: An integer or tuple/list of 3 integers, specifying the
-      depth, height and width of the 3D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 3 integers,
-      specifying the strides of the convolution along each spatial
-      dimension.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
+    filters: Integer, the dimensionality of the output space (i.e. the number of
+      output filters in the convolution).
+    kernel_size: An integer or tuple/list of 3 integers, specifying the depth,
+      height and width of the 3D convolution window. Can be a single integer to
+      specify the same value for all spatial dimensions.
+    strides: An integer or tuple/list of 3 integers, specifying the strides of
+      the convolution along each spatial dimension. Can be a single integer to
+      specify the same value for all spatial dimensions. Specifying any stride
+      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
     padding: one of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string,
-      one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
-      while `channels_first` corresponds to inputs with shape
-      `batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
-      It defaults to the `image_data_format` value found in your
-      Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be "channels_last".
-    dilation_rate: an integer or tuple/list of 3 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs. `channels_last` corresponds
+      to inputs with shape `batch_shape + (spatial_dim1, spatial_dim2,
+      spatial_dim3, channels)` while `channels_first` corresponds to inputs with
+      shape `batch_shape + (channels, spatial_dim1, spatial_dim2,
+      spatial_dim3)`. It defaults to the `image_data_format` value found in your
+      Keras config file at `~/.keras/keras.json`. If you never set it, then it
+      will be "channels_last".
+    dilation_rate: an integer or tuple/list of 3 integers, specifying the
+      dilation rate to use for dilated convolution. Can be a single integer to
+      specify the same value for all spatial dimensions. Currently, specifying
+      any `dilation_rate` value != 1 is incompatible with specifying any stride
+      value != 1.
     groups: A positive integer specifying the number of groups in which the
-      input is split along the channel axis. Each group is convolved
-      separately with `filters / groups` filters. The output is the
-      concatenation of all the `groups` results along the channel axis.
-      Input channels and `filters` must both be divisible by `groups`.
-    activation: Activation function to use.
-      If you don't specify anything, no activation is applied (
-      see `keras.activations`).
+      input is split along the channel axis. Each group is convolved separately
+      with `filters / groups` filters. The output is the concatenation of all
+      the `groups` results along the channel axis. Input channels and `filters`
+      must both be divisible by `groups`.
+    activation: Activation function to use. If you don't specify anything, no
+      activation is applied (see `keras.activations`).
     use_bias: Boolean, whether the layer uses a bias vector.
-    kernel_initializer: Initializer for the `kernel` weights matrix (
-      see `keras.initializers`).
-    bias_initializer: Initializer for the bias vector (
-      see `keras.initializers`).
-    kernel_regularizer: Regularizer function applied to
-      the `kernel` weights matrix (
-      see `keras.regularizers`).
-    bias_regularizer: Regularizer function applied to the bias vector (
-      see `keras.regularizers`).
-    activity_regularizer: Regularizer function applied to
-      the output of the layer (its "activation") (
-      see `keras.regularizers`).
-    kernel_constraint: Constraint function applied to the kernel matrix (
-      see `keras.constraints`).
-    bias_constraint: Constraint function applied to the bias vector (
-      see `keras.constraints`).
-
+    kernel_initializer: Initializer for the `kernel` weights matrix (see
+      `keras.initializers`).
+    bias_initializer: Initializer for the bias vector (see
+      `keras.initializers`).
+    kernel_regularizer: Regularizer function applied to the `kernel` weights
+      matrix (see `keras.regularizers`).
+    bias_regularizer: Regularizer function applied to the bias vector (see
+      `keras.regularizers`).
+    activity_regularizer: Regularizer function applied to the output of the
+      layer (its "activation") (see `keras.regularizers`).
+    kernel_constraint: Constraint function applied to the kernel matrix (see
+      `keras.constraints`).
+    bias_constraint: Constraint function applied to the bias vector (see
+      `keras.constraints`).
   Input shape:
-    5+D tensor with shape:
-    `batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
-      data_format='channels_first'
-    or 5+D tensor with shape:
-    `batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
-      data_format='channels_last'.
-
+    5+D tensor with shape: `batch_shape + (channels, conv_dim1, conv_dim2,
+      conv_dim3)` if data_format='channels_first'
+    or 5+D tensor with shape: `batch_shape + (conv_dim1, conv_dim2, conv_dim3,
+      channels)` if data_format='channels_last'.
   Output shape:
-    5+D tensor with shape:
-    `batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
-      data_format='channels_first'
-    or 5+D tensor with shape:
-    `batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
-      data_format='channels_last'.
-    `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
-      changed due to padding.
+    5+D tensor with shape: `batch_shape + (filters, new_conv_dim1,
+      new_conv_dim2, new_conv_dim3)` if data_format='channels_first'
+    or 5+D tensor with shape: `batch_shape + (new_conv_dim1, new_conv_dim2,
+      new_conv_dim3, filters)` if data_format='channels_last'. `new_conv_dim1`,
+      `new_conv_dim2` and `new_conv_dim3` values might have changed due to
+      padding.
 
   Returns:
     A tensor of rank 5+ representing
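A usage sketch matching the (4, 7, 26, 26, 26, 2) example output shown at the
top of this hunk, with a rank-2 batch:

    import numpy as np
    import tensorflow as tf

    # 6D input: batch_shape (4, 7) + (conv_dim1, conv_dim2, conv_dim3, channels).
    x = np.random.rand(4, 7, 28, 28, 28, 1).astype('float32')
    y = tf.keras.layers.Conv3D(filters=2, kernel_size=3)(x)
    print(y.shape)  # (4, 7, 26, 26, 26, 2)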
@@ -238,19 +238,6 @@ class Conv2DTest(keras_parameterized.TestCase):
       self._run_test(kwargs, expected_output_shape)
       self._run_test_extra_batch_dim(kwargs, expected_output_shape)
 
-  def test_conv2d_op_not_recreated_on_different_batch_shape(self):
-    layer = keras.layers.Conv2D(2, 3)
-    layer(np.ones((1, 28, 28, 2)))
-    # pylint: disable=protected-access
-    old_conv_op = layer._convolution_op
-    # Expand batch to rank-2 shape (5, 5)
-    layer(np.ones((5, 5, 28, 28, 2)))
-    self.assertEqual(old_conv_op, layer._convolution_op)
-    layer(np.ones((1, 30, 30, 2)))
-    # 'HW' changed, so the conv object is rebuilt
-    self.assertNotEqual(old_conv_op, layer._convolution_op)
-    # pylint: enable=protected-access
-
   def test_conv2d_regularizers(self):
     kwargs = {
         'filters': 3,
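The deleted test is obsolete by construction: _convolution_op is now a
functools.partial created once in build() and never swapped out, so there is
no recreation left to assert on. A hedged sketch of the invariant that should
hold after this change (it pokes the private attribute, as the old test did):

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.Conv2D(2, 3)
    layer(np.ones((1, 28, 28, 2), dtype='float32'))
    op = layer._convolution_op  # pylint: disable=protected-access
    layer(np.ones((5, 5, 28, 28, 2), dtype='float32'))  # extra batch dim
    layer(np.ones((1, 30, 30, 2), dtype='float32'))     # different H, W
    assert op is layer._convolution_op  # bound once in build(), never rebuilt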