[TF] Add support for higher rank batch shape to keras.layers.Conv{,2D}.

The support is added to the base Conv layer, but the underlying 1D and 3D NN
ops don't yet support it. That is now straightforward to add: if this change
goes in, I can add support for nn_ops conv1d, conv3d, and the Keras Conv1D and
Conv3D layers all together in a follow-up.
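
For illustration, here is what the change enables at the layer level (a
sketch mirroring the new Conv2D doctest added in this change):

  >>> x = tf.random.normal((4, 7, 28, 28, 3))  # extended batch shape [4, 7]
  >>> y = tf.keras.layers.Conv2D(2, 3, activation='relu')(x)
  >>> print(y.shape)
  (4, 7, 26, 26, 2)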

PiperOrigin-RevId: 314331499
Change-Id: Id1f723993881c206d006711afd3fc7921459df11
commit 4b123fc001 (parent f1dd2e7a45)
Author: Eugene Brevdo
Date: 2020-06-02 08:04:35 -07:00 (committed by TensorFlower Gardener)
3 changed files with 81 additions and 36 deletions

--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py

@@ -162,7 +162,7 @@ class Conv(Layer):
     self.bias_regularizer = regularizers.get(bias_regularizer)
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)
-    self.input_spec = InputSpec(ndim=self.rank + 2)
+    self.input_spec = InputSpec(min_ndim=self.rank + 2)

   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
@@ -196,7 +196,7 @@ class Conv(Layer):
     else:
       self.bias = None
     channel_axis = self._get_channel_axis()
-    self.input_spec = InputSpec(ndim=self.rank + 2,
+    self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                 axes={channel_axis: input_channel})

     self._build_conv_op_input_shape = input_shape
@@ -210,18 +210,20 @@ class Conv(Layer):
           dilation_rate=self.dilation_rate,
           strides=self.strides,
           padding=self._padding_op,
-          data_format=self._conv_op_data_format)
+          data_format=self._conv_op_data_format,
+          num_spatial_dims=self.rank)
     self.built = True

   def call(self, inputs):
     if self._recreate_conv_op(inputs):
       self._convolution_op = nn_ops.Convolution(
-          inputs.get_shape(),
+          inputs.shape,
           filter_shape=self.kernel.shape,
           dilation_rate=self.dilation_rate,
           strides=self.strides,
           padding=self._padding_op,
-          data_format=self._conv_op_data_format)
+          data_format=self._conv_op_data_format,
+          num_spatial_dims=self.rank)
       self._build_conv_op_input_shape = inputs.get_shape()

     # Apply causal padding to inputs for Conv1D.
@@ -231,15 +233,30 @@ class Conv(Layer):
     outputs = self._convolution_op(inputs, self.kernel)

     if self.use_bias:
+      outputs_rank = outputs.shape.ndims
       if self.data_format == 'channels_first':
         if self.rank == 1:
           # nn.bias_add does not accept a 1D input tensor.
           bias = array_ops.reshape(self.bias, (1, self.filters, 1))
           outputs += bias
         else:
-          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
+          if outputs_rank is not None and outputs_rank > 2 + self.rank:
+            # larger batch rank
+            outputs = nn_ops.squeeze_batch_dims(
+                outputs,
+                lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
+                inner_rank=self.rank + 1)
+          else:
+            outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
       else:
-        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
+        if outputs_rank is not None and outputs_rank > 2 + self.rank:
+          # larger batch rank
+          outputs = nn_ops.squeeze_batch_dims(
+              outputs,
+              lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
+              inner_rank=self.rank + 1)
+        else:
+          outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

     if self.activation is not None:
       return self.activation(outputs)
@@ -258,14 +275,16 @@ class Conv(Layer):

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    batch_rank = len(input_shape) - self.rank - 1
     if self.data_format == 'channels_last':
       return tensor_shape.TensorShape(
-          [input_shape[0]] + self._spatial_output_shape(input_shape[1:-1]) +
-          [self.filters])
+          input_shape[:batch_rank]
+          + self._spatial_output_shape(input_shape[batch_rank:-1])
+          + [self.filters])
     else:
       return tensor_shape.TensorShape(
-          [input_shape[0], self.filters] +
-          self._spatial_output_shape(input_shape[2:]))
+          input_shape[:batch_rank] + [self.filters] +
+          self._spatial_output_shape(input_shape[batch_rank + 1:]))

   def get_config(self):
     config = {
@@ -316,7 +335,7 @@ class Conv(Layer):

   def _get_channel_axis(self):
     if self.data_format == 'channels_first':
-      return 1
+      return -1 - self.rank
     else:
       return -1
@@ -350,7 +369,7 @@ class Conv(Layer):
     Returns:
       `True` or `False` to indicate whether to recreate the conv_op.
     """
-    call_input_shape = inputs.get_shape()
+    call_input_shape = inputs.shape
     # If the most specific compatible shape between _build_input_shape and
     # call_input_shape is not _build_input_shape then we must re-build.
     return self._build_conv_op_input_shape.most_specific_compatible_shape(
@@ -381,7 +400,7 @@ class Conv1D(Conv):
   >>> input_shape = (4, 10, 128)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv1D(
-  ... 32, 3, activation='relu',input_shape=input_shape)(x)
+  ... 32, 3, activation='relu',input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 8, 32)
@@ -508,7 +527,7 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 26, 26, 2)
@@ -516,7 +535,7 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 24, 24, 2)
@@ -524,10 +543,18 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', padding="same", input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', padding="same", input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 28, 28, 2)

+  >>> # With extended batch shape [4, 7]:
+  >>> input_shape = (4, 7, 28, 28, 3)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv2D(
+  ... 2, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 26, 26, 2)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
@@ -552,7 +579,7 @@ class Conv2D(Conv):
       `(batch_size, channels, height, width)`.
       It defaults to the `image_data_format` value found in your
       Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be "channels_last".
+      If you never set it, then it will be `channels_last`.
     dilation_rate: an integer or tuple/list of 2 integers, specifying
       the dilation rate to use for dilated convolution.
       Can be a single integer to specify the same value for
@@ -585,25 +612,27 @@ class Conv2D(Conv):
       see `keras.constraints`).

   Input shape:
-    4D tensor with shape:
-    `(batch_size, channels, rows, cols)` if data_format='channels_first'
-    or 4D tensor with shape:
-    `(batch_size, rows, cols, channels)` if data_format='channels_last'.
+    4+D tensor with shape:
+    `batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
+    or 4+D tensor with shape:
+    `batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.

   Output shape:
-    4D tensor with shape:
-    `(batch_size, filters, new_rows, new_cols)` if data_format='channels_first'
-    or 4D tensor with shape:
-    `(batch_size, new_rows, new_cols, filters)` if data_format='channels_last'.
+    4+D tensor with shape:
+    `batch_shape + (filters, new_rows, new_cols)` if
+    `data_format='channels_first'` or 4+D tensor with shape:
+    `batch_shape + (new_rows, new_cols, filters)` if
+    `data_format='channels_last'`.
     `rows` and `cols` values might have changed due to padding.

   Returns:
-    A tensor of rank 4 representing
+    A tensor of rank 4+ representing
     `activation(conv2d(inputs, kernel) + bias)`.

   Raises:
-    ValueError: if `padding` is "causal".
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: if `padding` is `"causal"`.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """

   def __init__(self,
@@ -669,7 +698,7 @@ class Conv3D(Conv):
   >>> input_shape =(4, 28, 28, 28, 1)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv3D(
-  ... 2, 3, activation='relu', input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 26, 26, 26, 2)
@@ -1198,8 +1227,8 @@ class Conv2DTranspose(Conv2D):
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if len(input_shape) != 4:
-      raise ValueError('Inputs should have rank 4. Received input shape: ' +
-                       str(input_shape))
+      raise ValueError('Inputs should have rank 4. Received input '
+                       'shape: ' + str(input_shape))
     channel_axis = self._get_channel_axis()
     if input_shape.dims[channel_axis].value is None:
       raise ValueError('The channel dimension of the inputs '
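
Note: as a sanity check of the new compute_output_shape arithmetic above,
here is a minimal standalone sketch in plain Python (no TF; the helper names
are hypothetical, and only 'valid' padding with stride 1 is modeled):

  def _valid_output_length(length, kernel_size):
    # 'valid' padding, stride 1: length - kernel_size + 1.
    return length - kernel_size + 1

  def conv_output_shape(input_shape, rank, filters, kernel_size):
    # channels_last: everything before the spatial dims + channel is batch.
    batch_rank = len(input_shape) - rank - 1
    batch = list(input_shape[:batch_rank])
    spatial = [_valid_output_length(d, kernel_size)
               for d in input_shape[batch_rank:-1]]
    return batch + spatial + [filters]

  # Extended batch shape [4, 7], 3x3 kernel, 2 filters -- matches the doctest:
  assert conv_output_shape((4, 7, 28, 28, 3), rank=2, filters=2,
                           kernel_size=3) == [4, 7, 26, 26, 2]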

--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py

@@ -171,6 +171,21 @@ class Conv2DTest(keras_parameterized.TestCase):
         input_shape=(num_samples, num_row, num_col, stack_size),
         expected_output_shape=expected_output_shape)

+  def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
+    batch_shape = (2, 11)
+    stack_size = 3
+    num_row = 7
+    num_col = 6
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+      testing_utils.layer_test(
+          keras.layers.Conv2D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (num_row, num_col, stack_size),
+          expected_output_shape=expected_output_shape)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -205,6 +220,7 @@ class Conv2DTest(keras_parameterized.TestCase):
     kwargs['kernel_size'] = (3, 3)
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape)

   def test_conv2d_regularizers(self):
     kwargs = {

--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py

@@ -254,7 +254,7 @@ class _NonAtrousConvolution(object):
         name=self.name)


-def _squeeze_batch_dims(inp, op, inner_rank, name):
+def squeeze_batch_dims(inp, op, inner_rank, name=None):
   """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

   Where `squeeze_batch` reshapes `inp` to shape
@@ -272,7 +272,7 @@ def _squeeze_batch_dims(inp, op, inner_rank, name):
   Returns:
     `unsqueeze_batch_op(squeeze_batch(inp))`.
   """
-  with ops.name_scope(name, "Convolution", [inp]):
+  with ops.name_scope(name, "squeeze_batch_dims", [inp]):
     inp = ops.convert_to_tensor(inp, name="input")
     shape = inp.shape
@@ -2224,7 +2224,7 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
         data_format=data_format,
         dilations=dilations,
         name=name)
-  return _squeeze_batch_dims(
+  return squeeze_batch_dims(
       input,
       functools.partial(
           gen_nn_ops.conv2d,
@@ -2543,7 +2543,7 @@ def _conv2d_expanded_batch(
         data_format=data_format,
         dilations=dilations,
         name=name)
-  return _squeeze_batch_dims(
+  return squeeze_batch_dims(
       input,
       functools.partial(
          gen_nn_ops.conv2d,
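
Note: for reference, the semantics of the newly exported squeeze_batch_dims
are roughly the following. This is a simplified sketch that assumes a fully
known static shape; the real implementation in nn_ops.py also handles
partially unknown shapes:

  import tensorflow as tf

  def squeeze_batch_dims_sketch(inp, op, inner_rank):
    """Flattens leading batch dims, applies `op`, then restores them."""
    shape = inp.shape.as_list()        # assumes a fully-defined static shape
    batch_shape = shape[:-inner_rank]  # all leading (batch) dimensions
    inner_shape = shape[-inner_rank:]  # the trailing dims that `op` expects
    # Collapse the batch dims into one, apply op, then expand them back.
    squeezed = tf.reshape(inp, [-1] + inner_shape)
    out = op(squeezed)
    return tf.reshape(out, batch_shape + out.shape.as_list()[1:])

  # Example: nn.bias_add only understands a single batch dim, so the Keras
  # Conv layer routes rank-5 NHWC activations through the helper:
  x = tf.random.normal((4, 7, 26, 26, 2))
  bias = tf.zeros([2])
  y = squeeze_batch_dims_sketch(
      x, lambda o: tf.nn.bias_add(o, bias, data_format='NHWC'),
      inner_rank=3)  # self.rank + 1 == spatial dims + channel dim
  print(y.shape)  # (4, 7, 26, 26, 2)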