Reduce Conv2D.__call__ overhead by 45%, Conv3D.__call__ overhead by 30%, and

Conv1D.__call__ overhead by 15%

Removes the indirection caused by using the nn_ops.Convolution class rather
than nn_ops.convolution.

PiperOrigin-RevId: 315920158
Change-Id: Ia5db49a36870b424b70d9d04962b84697d89bf55
This commit is contained in:
Thomas O'Malley 2020-06-11 09:59:47 -07:00 committed by TensorFlower Gardener
parent 52806d3849
commit 17b63987e5
3 changed files with 181 additions and 228 deletions

View File

@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):
xla_names = _GetNodeNames(use_xla=True)
no_xla_names = _GetNodeNames(use_xla=False)
self.assertListEqual(
xla_names,
no_xla_names,
)
# CPU path creates some additional nodes to handle dilations.
# TODO(b/138804006): Remove this when CPU & GPU support dilations.
filtered_no_xla_names = []
for name in no_xla_names:
if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
continue
else:
filtered_no_xla_names.append(name)
self.assertListEqual(xla_names, filtered_no_xla_names)
def testConv1DNodeNameMatch(self):
input_sizes = [8, 16, 3]

View File

@ -19,6 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import six
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
@ -125,6 +128,7 @@ class Conv(Layer):
bias_constraint=None,
trainable=True,
name=None,
conv_op=None,
**kwargs):
super(Conv, self).__init__(
trainable=trainable,
@ -132,30 +136,22 @@ class Conv(Layer):
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.rank = rank
if filters is not None and not isinstance(filters, int):
if isinstance(filters, float):
filters = int(filters)
self.filters = filters
self.groups = groups or 1
if filters is not None and filters % self.groups != 0:
raise ValueError(
'The number of filters must be evenly divisible by the number of '
'groups. Received: groups={}, filters={}'.format(groups, filters))
self.kernel_size = conv_utils.normalize_tuple(
kernel_size, rank, 'kernel_size')
if not all(self.kernel_size):
raise ValueError('The argument `kernel_size` cannot contain 0(s). '
'Received: %s' % (kernel_size,))
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.padding = conv_utils.normalize_padding(padding)
if (self.padding == 'causal' and not isinstance(self,
(Conv1D, SeparableConv1D))):
raise ValueError('Causal padding is only supported for `Conv1D`'
'and ``SeparableConv1D`.')
self.data_format = conv_utils.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(
dilation_rate, rank, 'dilation_rate')
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
@ -164,6 +160,28 @@ class Conv(Layer):
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
self._validate_init()
self._is_causal = self.padding == 'causal'
self._channels_first = self.data_format == 'channels_first'
self._tf_data_format = conv_utils.convert_data_format(
self.data_format, self.rank + 2)
def _validate_init(self):
if self.filters is not None and self.filters % self.groups != 0:
raise ValueError(
'The number of filters must be evenly divisible by the number of '
'groups. Received: groups={}, filters={}'.format(
self.groups, self.filters))
if not all(self.kernel_size):
raise ValueError('The argument `kernel_size` cannot contain 0(s). '
'Received: %s' % (self.kernel_size,))
if (self.padding == 'causal' and not isinstance(self,
(Conv1D, SeparableConv1D))):
raise ValueError('Causal padding is only supported for `Conv1D`'
'and `SeparableConv1D`.')
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
input_channel = self._get_input_channel(input_shape)
@ -199,64 +217,53 @@ class Conv(Layer):
self.input_spec = InputSpec(min_ndim=self.rank + 2,
axes={channel_axis: input_channel})
self._build_conv_op_data_shape = input_shape[-(self.rank + 1):]
self._build_input_channel = input_channel
self._padding_op = self._get_padding_op()
self._conv_op_data_format = conv_utils.convert_data_format(
self.data_format, self.rank + 2)
self._convolution_op = nn_ops.Convolution(
input_shape,
filter_shape=self.kernel.shape,
dilation_rate=self.dilation_rate,
strides=self.strides,
padding=self._padding_op,
data_format=self._conv_op_data_format,
num_spatial_dims=self.rank)
# Convert Keras formats to TF native formats.
if self.padding == 'causal':
tf_padding = 'VALID' # Causal padding handled in `call`.
elif isinstance(self.padding, six.string_types):
tf_padding = self.padding.upper()
else:
tf_padding = self.padding
tf_dilations = self.dilation_rate
tf_strides = self.strides
tf_op_name = self.__class__.__name__
if tf_op_name == 'Conv1D':
tf_op_name = 'conv1d' # Backwards compat.
self._convolution_op = functools.partial(
nn_ops.convolution_v2,
strides=tf_strides,
padding=tf_padding,
dilations=tf_dilations,
data_format=self._tf_data_format,
name=tf_op_name)
self.built = True
def call(self, inputs):
if self._recreate_conv_op(inputs):
self._convolution_op = nn_ops.Convolution(
inputs.shape,
filter_shape=self.kernel.shape,
dilation_rate=self.dilation_rate,
strides=self.strides,
padding=self._padding_op,
data_format=self._conv_op_data_format,
num_spatial_dims=self.rank)
self._build_conv_op_data_shape = inputs.shape[-(self.rank + 1):]
# Apply causal padding to inputs for Conv1D.
if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
if self._is_causal: # Apply causal padding to inputs for Conv1D.
inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
outputs = self._convolution_op(inputs, self.kernel)
if self.use_bias:
outputs_rank = outputs.shape.ndims
if self.data_format == 'channels_first':
if self.rank == 1:
# nn.bias_add does not accept a 1D input tensor.
bias = array_ops.reshape(self.bias, (1, self.filters, 1))
outputs += bias
else:
if outputs_rank is not None and outputs_rank > 2 + self.rank:
# larger batch rank
outputs = nn_ops.squeeze_batch_dims(
outputs,
lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
inner_rank=self.rank + 1)
else:
outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
output_rank = outputs.shape.rank
if self.rank == 1 and self._channels_first:
# nn.bias_add does not accept a 1D input tensor.
bias = array_ops.reshape(self.bias, (1, self.filters, 1))
outputs += bias
else:
if outputs_rank is not None and outputs_rank > 2 + self.rank:
# larger batch rank
# Handle multiple batch dimensions.
if output_rank is not None and output_rank > 2 + self.rank:
def _apply_fn(o):
return nn.bias_add(o, self.bias, data_format=self._tf_data_format)
outputs = nn_ops.squeeze_batch_dims(
outputs,
lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
inner_rank=self.rank + 1)
outputs, _apply_fn, inner_rank=self.rank + 1)
else:
outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
outputs = nn.bias_add(
outputs, self.bias, data_format=self._tf_data_format)
if self.activation is not None:
return self.activation(outputs)
@ -286,6 +293,9 @@ class Conv(Layer):
input_shape[:batch_rank] + [self.filters] +
self._spatial_output_shape(input_shape[batch_rank + 1:]))
def _recreate_conv_op(self, inputs): # pylint: disable=unused-argument
return False
def get_config(self):
config = {
'filters':
@ -359,27 +369,6 @@ class Conv(Layer):
op_padding = op_padding.upper()
return op_padding
def _recreate_conv_op(self, inputs):
"""Recreate conv_op if necessary.
Check if the input_shape in call() is different from that in build().
If the most-specific input shape describing the build and call shapes is not
equal to the shape we currently built with, then we need to rebuild the
_convolution_op to avoid incorrect behavior.
Args:
inputs: The input data to call() method.
Returns:
`True` or `False` to indicate whether to recreate the conv_op.
"""
call_data_shape = inputs.shape[-(self.rank + 1):]
# If the most specific compatible shape between _build_data_shape and
# call_data_shape is not _build_data_shape then we must re-build.
return (self._build_conv_op_data_shape
!= self._build_conv_op_data_shape.most_specific_compatible_shape(
call_data_shape))
@keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
class Conv1D(Conv):
@ -572,74 +561,60 @@ class Conv2D(Conv):
Arguments:
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
filters: Integer, the dimensionality of the output space (i.e. the number of
output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the height
and width of the 2D convolution window. Can be a single integer to specify
the same value for all spatial dimensions.
strides: An integer or tuple/list of 2 integers, specifying the strides of
the convolution along the height and width. Can be a single integer to
specify the same value for all spatial dimensions. Specifying any stride
value != 1 is incompatible with specifying any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch_size, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch_size, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `channels_last`.
dilation_rate: an integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs. `channels_last` corresponds
to inputs with shape `(batch_size, height, width, channels)` while
`channels_first` corresponds to inputs with shape `(batch_size, channels,
height, width)`. It defaults to the `image_data_format` value found in
your Keras config file at `~/.keras/keras.json`. If you never set it, then
it will be `channels_last`.
dilation_rate: an integer or tuple/list of 2 integers, specifying the
dilation rate to use for dilated convolution. Can be a single integer to
specify the same value for all spatial dimensions. Currently, specifying
any `dilation_rate` value != 1 is incompatible with specifying any stride
value != 1.
groups: A positive integer specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters / groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function to use.
If you don't specify anything, no activation is applied (
see `keras.activations`).
input is split along the channel axis. Each group is convolved separately
with `filters / groups` filters. The output is the concatenation of all
the `groups` results along the channel axis. Input channels and `filters`
must both be divisible by `groups`.
activation: Activation function to use. If you don't specify anything, no
activation is applied (see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (
see `keras.initializers`).
bias_initializer: Initializer for the bias vector (
see `keras.initializers`).
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
see `keras.regularizers`).
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation") (
see `keras.regularizers`).
kernel_constraint: Constraint function applied to the kernel matrix (
see `keras.constraints`).
bias_constraint: Constraint function applied to the bias vector (
see `keras.constraints`).
kernel_initializer: Initializer for the `kernel` weights matrix (see
`keras.initializers`).
bias_initializer: Initializer for the bias vector (see
`keras.initializers`).
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (see
`keras.regularizers`).
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation") (see `keras.regularizers`).
kernel_constraint: Constraint function applied to the kernel matrix (see
`keras.constraints`).
bias_constraint: Constraint function applied to the bias vector (see
`keras.constraints`).
Input shape:
4+D tensor with shape:
`batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
or 4+D tensor with shape:
`batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.
4+D tensor with shape: `batch_shape + (channels, rows, cols)` if
`data_format='channels_first'`
or 4+D tensor with shape: `batch_shape + (rows, cols, channels)` if
`data_format='channels_last'`.
Output shape:
4+D tensor with shape:
`batch_shape + (filters, new_rows, new_cols)` if
`data_format='channels_first'` or 4+D tensor with shape:
`batch_shape + (new_rows, new_cols, filters)` if
`data_format='channels_last'`.
`rows` and `cols` values might have changed due to padding.
4+D tensor with shape: `batch_shape + (filters, new_rows, new_cols)` if
`data_format='channels_first'` or 4+D tensor with shape: `batch_shape +
(new_rows, new_cols, filters)` if `data_format='channels_last'`. `rows`
and `cols` values might have changed due to padding.
Returns:
A tensor of rank 4+ representing
@ -727,79 +702,63 @@ class Conv3D(Conv):
(4, 7, 26, 26, 26, 2)
Arguments:
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 3 integers, specifying the
depth, height and width of the 3D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 3 integers,
specifying the strides of the convolution along each spatial
dimension.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
filters: Integer, the dimensionality of the output space (i.e. the number of
output filters in the convolution).
kernel_size: An integer or tuple/list of 3 integers, specifying the depth,
height and width of the 3D convolution window. Can be a single integer to
specify the same value for all spatial dimensions.
strides: An integer or tuple/list of 3 integers, specifying the strides of
the convolution along each spatial dimension. Can be a single integer to
specify the same value for all spatial dimensions. Specifying any stride
value != 1 is incompatible with specifying any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `channels_first` corresponds to inputs with shape
`batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
dilation_rate: an integer or tuple/list of 3 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
all spatial dimensions.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any stride value != 1.
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs. `channels_last` corresponds
to inputs with shape `batch_shape + (spatial_dim1, spatial_dim2,
spatial_dim3, channels)` while `channels_first` corresponds to inputs with
shape `batch_shape + (channels, spatial_dim1, spatial_dim2,
spatial_dim3)`. It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`. If you never set it, then it
will be "channels_last".
dilation_rate: an integer or tuple/list of 3 integers, specifying the
dilation rate to use for dilated convolution. Can be a single integer to
specify the same value for all spatial dimensions. Currently, specifying
any `dilation_rate` value != 1 is incompatible with specifying any stride
value != 1.
groups: A positive integer specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters / groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function to use.
If you don't specify anything, no activation is applied (
see `keras.activations`).
input is split along the channel axis. Each group is convolved separately
with `filters / groups` filters. The output is the concatenation of all
the `groups` results along the channel axis. Input channels and `filters`
must both be divisible by `groups`.
activation: Activation function to use. If you don't specify anything, no
activation is applied (see `keras.activations`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix (
see `keras.initializers`).
bias_initializer: Initializer for the bias vector (
see `keras.initializers`).
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix (
see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (
see `keras.regularizers`).
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation") (
see `keras.regularizers`).
kernel_constraint: Constraint function applied to the kernel matrix (
see `keras.constraints`).
bias_constraint: Constraint function applied to the bias vector (
see `keras.constraints`).
kernel_initializer: Initializer for the `kernel` weights matrix (see
`keras.initializers`).
bias_initializer: Initializer for the bias vector (see
`keras.initializers`).
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix (see `keras.regularizers`).
bias_regularizer: Regularizer function applied to the bias vector (see
`keras.regularizers`).
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation") (see `keras.regularizers`).
kernel_constraint: Constraint function applied to the kernel matrix (see
`keras.constraints`).
bias_constraint: Constraint function applied to the bias vector (see
`keras.constraints`).
Input shape:
5+D tensor with shape:
`batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
data_format='channels_first'
or 5+D tensor with shape:
`batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
data_format='channels_last'.
5+D tensor with shape: `batch_shape + (channels, conv_dim1, conv_dim2,
conv_dim3)` if data_format='channels_first'
or 5+D tensor with shape: `batch_shape + (conv_dim1, conv_dim2, conv_dim3,
channels)` if data_format='channels_last'.
Output shape:
5+D tensor with shape:
`batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
data_format='channels_first'
or 5+D tensor with shape:
`batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
data_format='channels_last'.
`new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
changed due to padding.
5+D tensor with shape: `batch_shape + (filters, new_conv_dim1,
new_conv_dim2, new_conv_dim3)` if data_format='channels_first'
or 5+D tensor with shape: `batch_shape + (new_conv_dim1, new_conv_dim2,
new_conv_dim3, filters)` if data_format='channels_last'. `new_conv_dim1`,
`new_conv_dim2` and `new_conv_dim3` values might have changed due to
padding.
Returns:
A tensor of rank 5+ representing

View File

@ -238,19 +238,6 @@ class Conv2DTest(keras_parameterized.TestCase):
self._run_test(kwargs, expected_output_shape)
self._run_test_extra_batch_dim(kwargs, expected_output_shape)
def test_conv2d_op_not_recreated_on_different_batch_shape(self):
layer = keras.layers.Conv2D(2, 3)
layer(np.ones((1, 28, 28, 2)))
# pylint: disable=protected-access
old_conv_op = layer._convolution_op
# Expand batch to rank-2 shape (5, 5)
layer(np.ones((5, 5, 28, 28, 2)))
self.assertEqual(old_conv_op, layer._convolution_op)
layer(np.ones((1, 30, 30, 2)))
# 'HW' changed, so the conv object is rebuilt
self.assertNotEqual(old_conv_op, layer._convolution_op)
# pylint: enable=protected-access
def test_conv2d_regularizers(self):
kwargs = {
'filters': 3,