Reduce Conv2D.__call__ overhead by 45%, Conv3D.__call__ overhead by 30%, and
Conv1D.__call__ overhead by 15%

Removes the indirection caused by using the nn_ops.Convolution class rather
than nn_ops.convolution.
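For context, the shape of the change: `build()` used to construct an
`nn_ops.Convolution` object, and `call()` re-created it whenever the input
shape changed; it now binds the static arguments once with `functools.partial`
around the plain convolution function. A minimal sketch of the new binding
(not the patch itself; it uses the public `tf.nn.convolution` in place of the
internal `nn_ops.convolution_v2` that backs it):

    import functools
    import tensorflow as tf

    # Static arguments are bound once at build time...
    conv_op = functools.partial(
        tf.nn.convolution, strides=1, padding='SAME', data_format='NHWC')

    # ...so each __call__ is a single function invocation, with no per-call
    # shape checks or object indirection.
    inputs = tf.ones((1, 28, 28, 3))   # NHWC
    kernel = tf.ones((3, 3, 3, 8))     # HWIO: 3x3 window, 3 in, 8 out channels
    outputs = conv_op(inputs, kernel)  # shape (1, 28, 28, 8)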

PiperOrigin-RevId: 315920158
Change-Id: Ia5db49a36870b424b70d9d04962b84697d89bf55
Thomas O'Malley 2020-06-11 09:59:47 -07:00 committed by TensorFlower Gardener
parent 52806d3849
commit 17b63987e5
3 changed files with 181 additions and 228 deletions


@@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):
     xla_names = _GetNodeNames(use_xla=True)
     no_xla_names = _GetNodeNames(use_xla=False)
-    self.assertListEqual(
-        xla_names,
-        no_xla_names,
-    )
+
+    # CPU path creates some additional nodes to handle dilations.
+    # TODO(b/138804006): Remove this when CPU & GPU support dilations.
+    filtered_no_xla_names = []
+    for name in no_xla_names:
+      if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
+        continue
+      else:
+        filtered_no_xla_names.append(name)
+
+    self.assertListEqual(xla_names, filtered_no_xla_names)

   def testConv1DNodeNameMatch(self):
     input_sizes = [8, 16, 3]


@@ -19,6 +19,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
+
+import six
+
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
@@ -125,6 +128,7 @@ class Conv(Layer):
                bias_constraint=None,
                trainable=True,
                name=None,
+               conv_op=None,
                **kwargs):
     super(Conv, self).__init__(
         trainable=trainable,
@@ -132,30 +136,22 @@ class Conv(Layer):
         activity_regularizer=regularizers.get(activity_regularizer),
         **kwargs)
     self.rank = rank
-    if filters is not None and not isinstance(filters, int):
+    if isinstance(filters, float):
       filters = int(filters)
     self.filters = filters
     self.groups = groups or 1
-    if filters is not None and filters % self.groups != 0:
-      raise ValueError(
-          'The number of filters must be evenly divisible by the number of '
-          'groups. Received: groups={}, filters={}'.format(groups, filters))
     self.kernel_size = conv_utils.normalize_tuple(
         kernel_size, rank, 'kernel_size')
-    if not all(self.kernel_size):
-      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
-                       'Received: %s' % (kernel_size,))
     self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
     self.padding = conv_utils.normalize_padding(padding)
-    if (self.padding == 'causal' and not isinstance(self,
-                                                    (Conv1D, SeparableConv1D))):
-      raise ValueError('Causal padding is only supported for `Conv1D`'
-                       'and ``SeparableConv1D`.')
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.dilation_rate = conv_utils.normalize_tuple(
         dilation_rate, rank, 'dilation_rate')
     self.activation = activations.get(activation)
     self.use_bias = use_bias
     self.kernel_initializer = initializers.get(kernel_initializer)
     self.bias_initializer = initializers.get(bias_initializer)
     self.kernel_regularizer = regularizers.get(kernel_regularizer)
@@ -164,6 +160,28 @@ class Conv(Layer):
     self.bias_constraint = constraints.get(bias_constraint)
     self.input_spec = InputSpec(min_ndim=self.rank + 2)

+    self._validate_init()
+    self._is_causal = self.padding == 'causal'
+    self._channels_first = self.data_format == 'channels_first'
+    self._tf_data_format = conv_utils.convert_data_format(
+        self.data_format, self.rank + 2)
+
+  def _validate_init(self):
+    if self.filters is not None and self.filters % self.groups != 0:
+      raise ValueError(
+          'The number of filters must be evenly divisible by the number of '
+          'groups. Received: groups={}, filters={}'.format(
+              self.groups, self.filters))
+
+    if not all(self.kernel_size):
+      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
+                       'Received: %s' % (self.kernel_size,))
+
+    if (self.padding == 'causal' and not isinstance(self,
+                                                    (Conv1D, SeparableConv1D))):
+      raise ValueError('Causal padding is only supported for `Conv1D`'
+                       'and `SeparableConv1D`.')
+
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     input_channel = self._get_input_channel(input_shape)
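The `_validate_init` helper above only consolidates checks that already ran
inline in `__init__`; construction-time errors are unchanged. A quick hedged
illustration through the public tf.keras API:

    import tensorflow as tf

    # Raises ValueError: 3 filters cannot be evenly divided into 2 groups,
    # per the first check in _validate_init.
    try:
      tf.keras.layers.Conv2D(filters=3, kernel_size=3, groups=2)
    except ValueError as e:
      print(e)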
@@ -199,64 +217,53 @@ class Conv(Layer):
     self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                 axes={channel_axis: input_channel})
-    self._build_conv_op_data_shape = input_shape[-(self.rank + 1):]
-    self._build_input_channel = input_channel
-    self._padding_op = self._get_padding_op()
-    self._conv_op_data_format = conv_utils.convert_data_format(
-        self.data_format, self.rank + 2)
-    self._convolution_op = nn_ops.Convolution(
-        input_shape,
-        filter_shape=self.kernel.shape,
-        dilation_rate=self.dilation_rate,
-        strides=self.strides,
-        padding=self._padding_op,
-        data_format=self._conv_op_data_format,
-        num_spatial_dims=self.rank)
+
+    # Convert Keras formats to TF native formats.
+    if self.padding == 'causal':
+      tf_padding = 'VALID'  # Causal padding handled in `call`.
+    elif isinstance(self.padding, six.string_types):
+      tf_padding = self.padding.upper()
+    else:
+      tf_padding = self.padding
+    tf_dilations = self.dilation_rate
+    tf_strides = self.strides
+
+    tf_op_name = self.__class__.__name__
+    if tf_op_name == 'Conv1D':
+      tf_op_name = 'conv1d'  # Backwards compat.
+
+    self._convolution_op = functools.partial(
+        nn_ops.convolution_v2,
+        strides=tf_strides,
+        padding=tf_padding,
+        dilations=tf_dilations,
+        data_format=self._tf_data_format,
+        name=tf_op_name)
     self.built = True

   def call(self, inputs):
-    if self._recreate_conv_op(inputs):
-      self._convolution_op = nn_ops.Convolution(
-          inputs.shape,
-          filter_shape=self.kernel.shape,
-          dilation_rate=self.dilation_rate,
-          strides=self.strides,
-          padding=self._padding_op,
-          data_format=self._conv_op_data_format,
-          num_spatial_dims=self.rank)
-      self._build_conv_op_data_shape = inputs.shape[-(self.rank + 1):]
-
-    # Apply causal padding to inputs for Conv1D.
-    if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
+    if self._is_causal:  # Apply causal padding to inputs for Conv1D.
       inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))

     outputs = self._convolution_op(inputs, self.kernel)

     if self.use_bias:
-      outputs_rank = outputs.shape.ndims
-      if self.data_format == 'channels_first':
-        if self.rank == 1:
-          # nn.bias_add does not accept a 1D input tensor.
-          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
-          outputs += bias
-        else:
-          if outputs_rank is not None and outputs_rank > 2 + self.rank:
-            # larger batch rank
-            outputs = nn_ops.squeeze_batch_dims(
-                outputs,
-                lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
-                inner_rank=self.rank + 1)
-          else:
-            outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
+      output_rank = outputs.shape.rank
+      if self.rank == 1 and self._channels_first:
+        # nn.bias_add does not accept a 1D input tensor.
+        bias = array_ops.reshape(self.bias, (1, self.filters, 1))
+        outputs += bias
       else:
-        if outputs_rank is not None and outputs_rank > 2 + self.rank:
-          # larger batch rank
+        # Handle multiple batch dimensions.
+        if output_rank is not None and output_rank > 2 + self.rank:
+
+          def _apply_fn(o):
+            return nn.bias_add(o, self.bias, data_format=self._tf_data_format)
+
           outputs = nn_ops.squeeze_batch_dims(
-              outputs,
-              lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
-              inner_rank=self.rank + 1)
+              outputs, _apply_fn, inner_rank=self.rank + 1)
         else:
-          outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
+          outputs = nn.bias_add(
+              outputs, self.bias, data_format=self._tf_data_format)

     if self.activation is not None:
       return self.activation(outputs)
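Note that the new bias path still routes outputs of rank greater than
`2 + self.rank` through `nn_ops.squeeze_batch_dims`, so inputs with extra
leading batch dimensions keep working. A hedged example of the preserved
behavior, mirroring the test deleted in the last file of this commit:

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.Conv2D(2, 3)
    # Rank-5 input: the (5, 5) batch dims are flattened for the conv and
    # bias-add, then restored.
    outputs = layer(np.ones((5, 5, 28, 28, 2), dtype=np.float32))
    print(outputs.shape)  # (5, 5, 26, 26, 2)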
@@ -286,6 +293,9 @@ class Conv(Layer):
         input_shape[:batch_rank] + [self.filters] +
         self._spatial_output_shape(input_shape[batch_rank + 1:]))

+  def _recreate_conv_op(self, inputs):  # pylint: disable=unused-argument
+    return False
+
   def get_config(self):
     config = {
         'filters':
@@ -359,27 +369,6 @@ class Conv(Layer):
       op_padding = op_padding.upper()
     return op_padding

-  def _recreate_conv_op(self, inputs):
-    """Recreate conv_op if necessary.
-
-    Check if the input_shape in call() is different from that in build().
-    If the most-specific input shape describing the build and call shapes is not
-    equal to the shape we currently built with, then we need to rebuild the
-    _convolution_op to avoid incorrect behavior.
-
-    Args:
-      inputs: The input data to call() method.
-
-    Returns:
-      `True` or `False` to indicate whether to recreate the conv_op.
-    """
-    call_data_shape = inputs.shape[-(self.rank + 1):]
-    # If the most specific compatible shape between _build_data_shape and
-    # call_data_shape is not _build_data_shape then we must re-build.
-    return (self._build_conv_op_data_shape
-            != self._build_conv_op_data_shape.most_specific_compatible_shape(
-                call_data_shape))
-

 @keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
 class Conv1D(Conv):
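The deleted rebuild check compared the build-time shape against the most
specific shape compatible with each call's input, via the public
`TensorShape.most_specific_compatible_shape`; with `_convolution_op` now a
shape-agnostic partial, no rebuild is ever needed. A small sketch of what the
removed predicate computed:

    import tensorflow as tf

    build_shape = tf.TensorShape([28, 28, 2])
    call_shape = tf.TensorShape([30, 30, 2])
    # Dimensions that disagree relax to None, so the result differs from
    # build_shape and the old code would have rebuilt the conv op.
    print(build_shape.most_specific_compatible_shape(call_shape))
    # -> (None, None, 2)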
@@ -572,74 +561,60 @@ class Conv2D(Conv):

   Arguments:
-    filters: Integer, the dimensionality of the output space
-      (i.e. the number of output filters in the convolution).
-    kernel_size: An integer or tuple/list of 2 integers, specifying the
-      height and width of the 2D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 2 integers,
-      specifying the strides of the convolution along the height and width.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
+    filters: Integer, the dimensionality of the output space (i.e. the number of
+      output filters in the convolution).
+    kernel_size: An integer or tuple/list of 2 integers, specifying the height
+      and width of the 2D convolution window. Can be a single integer to specify
+      the same value for all spatial dimensions.
+    strides: An integer or tuple/list of 2 integers, specifying the strides of
+      the convolution along the height and width. Can be a single integer to
+      specify the same value for all spatial dimensions. Specifying any stride
+      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
     padding: one of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string,
-      one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `(batch_size, height, width, channels)` while `channels_first`
-      corresponds to inputs with shape
-      `(batch_size, channels, height, width)`.
-      It defaults to the `image_data_format` value found in your
-      Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be `channels_last`.
-    dilation_rate: an integer or tuple/list of 2 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs. `channels_last` corresponds
+      to inputs with shape `(batch_size, height, width, channels)` while
+      `channels_first` corresponds to inputs with shape `(batch_size, channels,
+      height, width)`. It defaults to the `image_data_format` value found in
+      your Keras config file at `~/.keras/keras.json`. If you never set it, then
+      it will be `channels_last`.
+    dilation_rate: an integer or tuple/list of 2 integers, specifying the
+      dilation rate to use for dilated convolution. Can be a single integer to
+      specify the same value for all spatial dimensions. Currently, specifying
+      any `dilation_rate` value != 1 is incompatible with specifying any stride
+      value != 1.
     groups: A positive integer specifying the number of groups in which the
-      input is split along the channel axis. Each group is convolved
-      separately with `filters / groups` filters. The output is the
-      concatenation of all the `groups` results along the channel axis.
-      Input channels and `filters` must both be divisible by `groups`.
-    activation: Activation function to use.
-      If you don't specify anything, no activation is applied (
-      see `keras.activations`).
+      input is split along the channel axis. Each group is convolved separately
+      with `filters / groups` filters. The output is the concatenation of all
+      the `groups` results along the channel axis. Input channels and `filters`
+      must both be divisible by `groups`.
+    activation: Activation function to use. If you don't specify anything, no
+      activation is applied (see `keras.activations`).
     use_bias: Boolean, whether the layer uses a bias vector.
-    kernel_initializer: Initializer for the `kernel` weights matrix (
-      see `keras.initializers`).
-    bias_initializer: Initializer for the bias vector (
-      see `keras.initializers`).
-    kernel_regularizer: Regularizer function applied to
-      the `kernel` weights matrix (see `keras.regularizers`).
-    bias_regularizer: Regularizer function applied to the bias vector (
-      see `keras.regularizers`).
-    activity_regularizer: Regularizer function applied to
-      the output of the layer (its "activation") (
-      see `keras.regularizers`).
-    kernel_constraint: Constraint function applied to the kernel matrix (
-      see `keras.constraints`).
-    bias_constraint: Constraint function applied to the bias vector (
-      see `keras.constraints`).
+    kernel_initializer: Initializer for the `kernel` weights matrix (see
+      `keras.initializers`).
+    bias_initializer: Initializer for the bias vector (see
+      `keras.initializers`).
+    kernel_regularizer: Regularizer function applied to the `kernel` weights
+      matrix (see `keras.regularizers`).
+    bias_regularizer: Regularizer function applied to the bias vector (see
+      `keras.regularizers`).
+    activity_regularizer: Regularizer function applied to the output of the
+      layer (its "activation") (see `keras.regularizers`).
+    kernel_constraint: Constraint function applied to the kernel matrix (see
+      `keras.constraints`).
+    bias_constraint: Constraint function applied to the bias vector (see
+      `keras.constraints`).

   Input shape:
-    4+D tensor with shape:
-    `batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
-    or 4+D tensor with shape:
-    `batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.
+    4+D tensor with shape: `batch_shape + (channels, rows, cols)` if
+      `data_format='channels_first'`
+    or 4+D tensor with shape: `batch_shape + (rows, cols, channels)` if
+      `data_format='channels_last'`.

   Output shape:
-    4+D tensor with shape:
-    `batch_shape + (filters, new_rows, new_cols)` if
-    `data_format='channels_first'` or 4+D tensor with shape:
-    `batch_shape + (new_rows, new_cols, filters)` if
-    `data_format='channels_last'`.
-    `rows` and `cols` values might have changed due to padding.
+    4+D tensor with shape: `batch_shape + (filters, new_rows, new_cols)` if
+      `data_format='channels_first'` or 4+D tensor with shape: `batch_shape +
+      (new_rows, new_cols, filters)` if `data_format='channels_last'`. `rows`
+      and `cols` values might have changed due to padding.

   Returns:
     A tensor of rank 4+ representing
@@ -727,79 +702,63 @@ class Conv3D(Conv):
   (4, 7, 26, 26, 26, 2)

   Arguments:
-    filters: Integer, the dimensionality of the output space
-      (i.e. the number of output filters in the convolution).
-    kernel_size: An integer or tuple/list of 3 integers, specifying the
-      depth, height and width of the 3D convolution window.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-    strides: An integer or tuple/list of 3 integers,
-      specifying the strides of the convolution along each spatial
-      dimension.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Specifying any stride value != 1 is incompatible with specifying
-      any `dilation_rate` value != 1.
+    filters: Integer, the dimensionality of the output space (i.e. the number of
+      output filters in the convolution).
+    kernel_size: An integer or tuple/list of 3 integers, specifying the depth,
+      height and width of the 3D convolution window. Can be a single integer to
+      specify the same value for all spatial dimensions.
+    strides: An integer or tuple/list of 3 integers, specifying the strides of
+      the convolution along each spatial dimension. Can be a single integer to
+      specify the same value for all spatial dimensions. Specifying any stride
+      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
     padding: one of `"valid"` or `"same"` (case-insensitive).
-    data_format: A string,
-      one of `channels_last` (default) or `channels_first`.
-      The ordering of the dimensions in the inputs.
-      `channels_last` corresponds to inputs with shape
-      `batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
-      while `channels_first` corresponds to inputs with shape
-      `batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
-      It defaults to the `image_data_format` value found in your
-      Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be "channels_last".
-    dilation_rate: an integer or tuple/list of 3 integers, specifying
-      the dilation rate to use for dilated convolution.
-      Can be a single integer to specify the same value for
-      all spatial dimensions.
-      Currently, specifying any `dilation_rate` value != 1 is
-      incompatible with specifying any stride value != 1.
+    data_format: A string, one of `channels_last` (default) or `channels_first`.
+      The ordering of the dimensions in the inputs. `channels_last` corresponds
+      to inputs with shape `batch_shape + (spatial_dim1, spatial_dim2,
+      spatial_dim3, channels)` while `channels_first` corresponds to inputs with
+      shape `batch_shape + (channels, spatial_dim1, spatial_dim2,
+      spatial_dim3)`. It defaults to the `image_data_format` value found in your
+      Keras config file at `~/.keras/keras.json`. If you never set it, then it
+      will be "channels_last".
+    dilation_rate: an integer or tuple/list of 3 integers, specifying the
+      dilation rate to use for dilated convolution. Can be a single integer to
+      specify the same value for all spatial dimensions. Currently, specifying
+      any `dilation_rate` value != 1 is incompatible with specifying any stride
+      value != 1.
     groups: A positive integer specifying the number of groups in which the
-      input is split along the channel axis. Each group is convolved
-      separately with `filters / groups` filters. The output is the
-      concatenation of all the `groups` results along the channel axis.
-      Input channels and `filters` must both be divisible by `groups`.
-    activation: Activation function to use.
-      If you don't specify anything, no activation is applied (
-      see `keras.activations`).
+      input is split along the channel axis. Each group is convolved separately
+      with `filters / groups` filters. The output is the concatenation of all
+      the `groups` results along the channel axis. Input channels and `filters`
+      must both be divisible by `groups`.
+    activation: Activation function to use. If you don't specify anything, no
+      activation is applied (see `keras.activations`).
     use_bias: Boolean, whether the layer uses a bias vector.
-    kernel_initializer: Initializer for the `kernel` weights matrix (
-      see `keras.initializers`).
-    bias_initializer: Initializer for the bias vector (
-      see `keras.initializers`).
-    kernel_regularizer: Regularizer function applied to
-      the `kernel` weights matrix (
-      see `keras.regularizers`).
-    bias_regularizer: Regularizer function applied to the bias vector (
-      see `keras.regularizers`).
-    activity_regularizer: Regularizer function applied to
-      the output of the layer (its "activation") (
-      see `keras.regularizers`).
-    kernel_constraint: Constraint function applied to the kernel matrix (
-      see `keras.constraints`).
-    bias_constraint: Constraint function applied to the bias vector (
-      see `keras.constraints`).
+    kernel_initializer: Initializer for the `kernel` weights matrix (see
+      `keras.initializers`).
+    bias_initializer: Initializer for the bias vector (see
+      `keras.initializers`).
+    kernel_regularizer: Regularizer function applied to the `kernel` weights
+      matrix (see `keras.regularizers`).
+    bias_regularizer: Regularizer function applied to the bias vector (see
+      `keras.regularizers`).
+    activity_regularizer: Regularizer function applied to the output of the
+      layer (its "activation") (see `keras.regularizers`).
+    kernel_constraint: Constraint function applied to the kernel matrix (see
+      `keras.constraints`).
+    bias_constraint: Constraint function applied to the bias vector (see
+      `keras.constraints`).

   Input shape:
-    5+D tensor with shape:
-    `batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
-      data_format='channels_first'
-    or 5+D tensor with shape:
-    `batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
-      data_format='channels_last'.
+    5+D tensor with shape: `batch_shape + (channels, conv_dim1, conv_dim2,
+      conv_dim3)` if data_format='channels_first'
+    or 5+D tensor with shape: `batch_shape + (conv_dim1, conv_dim2, conv_dim3,
+      channels)` if data_format='channels_last'.

   Output shape:
-    5+D tensor with shape:
-    `batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
-      data_format='channels_first'
-    or 5+D tensor with shape:
-    `batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
-      data_format='channels_last'.
-    `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
-      changed due to padding.
+    5+D tensor with shape: `batch_shape + (filters, new_conv_dim1,
+      new_conv_dim2, new_conv_dim3)` if data_format='channels_first'
+    or 5+D tensor with shape: `batch_shape + (new_conv_dim1, new_conv_dim2,
+      new_conv_dim3, filters)` if data_format='channels_last'.
+      `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
+      changed due to padding.

   Returns:
     A tensor of rank 5+ representing


@@ -238,19 +238,6 @@ class Conv2DTest(keras_parameterized.TestCase):
     self._run_test(kwargs, expected_output_shape)
     self._run_test_extra_batch_dim(kwargs, expected_output_shape)

-  def test_conv2d_op_not_recreated_on_different_batch_shape(self):
-    layer = keras.layers.Conv2D(2, 3)
-    layer(np.ones((1, 28, 28, 2)))
-    # pylint: disable=protected-access
-    old_conv_op = layer._convolution_op
-    # Expand batch to rank-2 shape (5, 5)
-    layer(np.ones((5, 5, 28, 28, 2)))
-    self.assertEqual(old_conv_op, layer._convolution_op)
-    layer(np.ones((1, 30, 30, 2)))
-    # 'HW' changed, so the conv object is rebuilt
-    self.assertNotEqual(old_conv_op, layer._convolution_op)
-    # pylint: enable=protected-access
-
   def test_conv2d_regularizers(self):
     kwargs = {
         'filters': 3,