[TF] Add support for higher rank batch shape to keras.layers.Conv{,2D}.

Support is added to the base Conv layer, but the underlying 1D and 3D NN ops do not
support it yet. That is now easy to add: if this change goes in, I can add support for
nn_ops conv1d and conv3d and the Keras Conv1D and Conv3D layers together in a
follow-up.

PiperOrigin-RevId: 314331499
Change-Id: Id1f723993881c206d006711afd3fc7921459df11
Author: Eugene Brevdo, 2020-06-02 08:04:35 -07:00 (committed by TensorFlower Gardener)
Commit: 4b123fc001 (parent: f1dd2e7a45)
3 changed files with 81 additions and 36 deletions
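
To illustrate the new behavior end to end (a sketch mirroring the docstring example
added in this change; shapes are illustrative only): with an extended batch shape
[4, 7], a Conv2D layer now applies the convolution over the trailing dimensions and
preserves the extra batch dimensions.

    >>> x = tf.random.normal((4, 7, 28, 28, 3))
    >>> y = tf.keras.layers.Conv2D(2, 3, activation='relu')(x)
    >>> print(y.shape)
    (4, 7, 26, 26, 2)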


@@ -162,7 +162,7 @@ class Conv(Layer):
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=self.rank + 2)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
@@ -196,7 +196,7 @@ class Conv(Layer):
else:
self.bias = None
channel_axis = self._get_channel_axis()
self.input_spec = InputSpec(ndim=self.rank + 2,
self.input_spec = InputSpec(min_ndim=self.rank + 2,
axes={channel_axis: input_channel})
self._build_conv_op_input_shape = input_shape
@@ -210,18 +210,20 @@ class Conv(Layer):
dilation_rate=self.dilation_rate,
strides=self.strides,
padding=self._padding_op,
data_format=self._conv_op_data_format)
data_format=self._conv_op_data_format,
num_spatial_dims=self.rank)
self.built = True
def call(self, inputs):
if self._recreate_conv_op(inputs):
self._convolution_op = nn_ops.Convolution(
inputs.get_shape(),
inputs.shape,
filter_shape=self.kernel.shape,
dilation_rate=self.dilation_rate,
strides=self.strides,
padding=self._padding_op,
data_format=self._conv_op_data_format)
data_format=self._conv_op_data_format,
num_spatial_dims=self.rank)
self._build_conv_op_input_shape = inputs.get_shape()
# Apply causal padding to inputs for Conv1D.
@@ -231,13 +233,28 @@ class Conv(Layer):
outputs = self._convolution_op(inputs, self.kernel)
if self.use_bias:
outputs_rank = outputs.shape.ndims
if self.data_format == 'channels_first':
if self.rank == 1:
# nn.bias_add does not accept a 1D input tensor.
bias = array_ops.reshape(self.bias, (1, self.filters, 1))
outputs += bias
else:
if outputs_rank is not None and outputs_rank > 2 + self.rank:
# larger batch rank
outputs = nn_ops.squeeze_batch_dims(
outputs,
lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
inner_rank=self.rank + 1)
else:
outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
else:
if outputs_rank is not None and outputs_rank > 2 + self.rank:
# larger batch rank
outputs = nn_ops.squeeze_batch_dims(
outputs,
lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
inner_rank=self.rank + 1)
else:
outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
@@ -258,14 +275,16 @@ class Conv(Layer):
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
batch_rank = len(input_shape) - self.rank - 1
if self.data_format == 'channels_last':
return tensor_shape.TensorShape(
[input_shape[0]] + self._spatial_output_shape(input_shape[1:-1]) +
[self.filters])
input_shape[:batch_rank]
+ self._spatial_output_shape(input_shape[batch_rank:-1])
+ [self.filters])
else:
return tensor_shape.TensorShape(
[input_shape[0], self.filters] +
self._spatial_output_shape(input_shape[2:]))
input_shape[:batch_rank] + [self.filters] +
self._spatial_output_shape(input_shape[batch_rank + 1:]))
def get_config(self):
config = {
@@ -316,7 +335,7 @@ class Conv(Layer):
def _get_channel_axis(self):
if self.data_format == 'channels_first':
return 1
return -1 - self.rank
else:
return -1
@@ -350,7 +369,7 @@ class Conv(Layer):
Returns:
`True` or `False` to indicate whether to recreate the conv_op.
"""
call_input_shape = inputs.get_shape()
call_input_shape = inputs.shape
# If the most specific compatible shape between _build_input_shape and
# call_input_shape is not _build_input_shape then we must re-build.
return self._build_conv_op_input_shape.most_specific_compatible_shape(
@@ -381,7 +400,7 @@ class Conv1D(Conv):
>>> input_shape = (4, 10, 128)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv1D(
... 32, 3, activation='relu',input_shape=input_shape)(x)
... 32, 3, activation='relu',input_shape=input_shape[1:])(x)
>>> print(y.shape)
(4, 8, 32)
@@ -508,7 +527,7 @@ class Conv2D(Conv):
>>> input_shape = (4, 28, 28, 3)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv2D(
... 2, 3, activation='relu', input_shape=input_shape)(x)
... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
>>> print(y.shape)
(4, 26, 26, 2)
@@ -516,7 +535,7 @@ class Conv2D(Conv):
>>> input_shape = (4, 28, 28, 3)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv2D(
... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape)(x)
... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape[1:])(x)
>>> print(y.shape)
(4, 24, 24, 2)
@@ -524,10 +543,18 @@ class Conv2D(Conv):
>>> input_shape = (4, 28, 28, 3)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv2D(
... 2, 3, activation='relu', padding="same", input_shape=input_shape)(x)
... 2, 3, activation='relu', padding="same", input_shape=input_shape[1:])(x)
>>> print(y.shape)
(4, 28, 28, 2)
>>> # With extended batch shape [4, 7]:
>>> input_shape = (4, 7, 28, 28, 3)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv2D(
... 2, 3, activation='relu', input_shape=input_shape[2:])(x)
>>> print(y.shape)
(4, 7, 26, 26, 2)
Arguments:
filters: Integer, the dimensionality of the output space
@@ -552,7 +579,7 @@ class Conv2D(Conv):
`(batch_size, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
If you never set it, then it will be `channels_last`.
dilation_rate: an integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
@@ -585,25 +612,27 @@ class Conv2D(Conv):
see `keras.constraints`).
Input shape:
4D tensor with shape:
`(batch_size, channels, rows, cols)` if data_format='channels_first'
or 4D tensor with shape:
`(batch_size, rows, cols, channels)` if data_format='channels_last'.
4+D tensor with shape:
`batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
or 4+D tensor with shape:
`batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.
Output shape:
4D tensor with shape:
`(batch_size, filters, new_rows, new_cols)` if data_format='channels_first'
or 4D tensor with shape:
`(batch_size, new_rows, new_cols, filters)` if data_format='channels_last'.
4+D tensor with shape:
`batch_shape + (filters, new_rows, new_cols)` if
`data_format='channels_first'` or 4+D tensor with shape:
`batch_shape + (new_rows, new_cols, filters)` if
`data_format='channels_last'`.
`rows` and `cols` values might have changed due to padding.
Returns:
A tensor of rank 4 representing
A tensor of rank 4+ representing
`activation(conv2d(inputs, kernel) + bias)`.
Raises:
ValueError: if `padding` is "causal".
ValueError: when both `strides` > 1 and `dilation_rate` > 1.
ValueError: if `padding` is `"causal"`.
ValueError: when both `strides > 1` and `dilation_rate > 1`.
"""
def __init__(self,
@@ -669,7 +698,7 @@ class Conv3D(Conv):
>>> input_shape =(4, 28, 28, 28, 1)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv3D(
... 2, 3, activation='relu', input_shape=input_shape)(x)
... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
>>> print(y.shape)
(4, 26, 26, 26, 2)
@@ -1198,8 +1227,8 @@ class Conv2DTranspose(Conv2D):
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 4:
raise ValueError('Inputs should have rank 4. Received input shape: ' +
str(input_shape))
raise ValueError('Inputs should have rank 4. Received input '
'shape: ' + str(input_shape))
channel_axis = self._get_channel_axis()
if input_shape.dims[channel_axis].value is None:
raise ValueError('The channel dimension of the inputs '
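
To make the compute_output_shape change above concrete, here is a small sketch of what
the new logic computes for channels_last data with an extended batch shape (hypothetical
shapes matching the docstring example; not part of the diff):

    layer = tf.keras.layers.Conv2D(2, 3, padding='valid')
    # batch_rank = len(input_shape) - rank - 1 = 5 - 2 - 1 = 2, so the leading
    # [4, 7] dims pass through unchanged and only the spatial dims shrink.
    print(layer.compute_output_shape((4, 7, 28, 28, 3)))  # (4, 7, 26, 26, 2)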


@@ -171,6 +171,21 @@ class Conv2DTest(keras_parameterized.TestCase):
input_shape=(num_samples, num_row, num_col, stack_size),
expected_output_shape=expected_output_shape)
def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
batch_shape = (2, 11)
stack_size = 3
num_row = 7
num_col = 6
with self.cached_session(use_gpu=True):
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
testing_utils.layer_test(
keras.layers.Conv2D,
kwargs=kwargs,
input_shape=batch_shape + (num_row, num_col, stack_size),
expected_output_shape=expected_output_shape)
@parameterized.named_parameters(
('padding_valid', {
'padding': 'valid'
@@ -205,6 +220,7 @@ class Conv2DTest(keras_parameterized.TestCase):
kwargs['kernel_size'] = (3, 3)
if not requires_gpu or test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, expected_output_shape)
self._run_test_extra_batch_dim(kwargs, expected_output_shape)
def test_conv2d_regularizers(self):
kwargs = {


@@ -254,7 +254,7 @@ class _NonAtrousConvolution(object):
name=self.name)
def _squeeze_batch_dims(inp, op, inner_rank, name):
def squeeze_batch_dims(inp, op, inner_rank, name=None):
"""Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.
Where `squeeze_batch` reshapes `inp` to shape
@@ -272,7 +272,7 @@ def _squeeze_batch_dims(inp, op, inner_rank, name):
Returns:
`unsqueeze_batch_op(squeeze_batch(inp))`.
"""
with ops.name_scope(name, "Convolution", [inp]):
with ops.name_scope(name, "squeeze_batch_dims", [inp]):
inp = ops.convert_to_tensor(inp, name="input")
shape = inp.shape
@@ -2224,7 +2224,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
data_format=data_format,
dilations=dilations,
name=name)
return _squeeze_batch_dims(
return squeeze_batch_dims(
input,
functools.partial(
gen_nn_ops.conv2d,
@@ -2543,7 +2543,7 @@ def _conv2d_expanded_batch(
data_format=data_format,
dilations=dilations,
name=name)
return _squeeze_batch_dims(
return squeeze_batch_dims(
input,
functools.partial(
gen_nn_ops.conv2d,
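
For reference, a minimal sketch of the idea behind squeeze_batch_dims as used by the
conv and bias-add paths above: collapse every leading batch dimension into one, apply
the rank-limited op, then restore the original batch shape. This is an illustration
only, not the library implementation (which also handles partially unknown shapes),
and it assumes fully defined static shapes:

    import tensorflow as tf

    def squeeze_batch_dims_sketch(inp, op, inner_rank):
      # Collapse all leading batch dims into a single dim, apply `op`, then
      # reshape the result back to the original batch shape.
      batch_shape = inp.shape[:-inner_rank].as_list()
      inner_shape = inp.shape[-inner_rank:].as_list()
      flat = tf.reshape(inp, [-1] + inner_shape)
      out = op(flat)
      return tf.reshape(out, batch_shape + out.shape[1:].as_list())

    # Example: wrap an op across extra batch dims; tf.nn.bias_add is used here
    # purely as a shape-preserving op for illustration.
    x = tf.zeros((4, 7, 26, 26, 2))
    y = squeeze_batch_dims_sketch(
        x, lambda o: tf.nn.bias_add(o, tf.zeros((2,))), inner_rank=3)
    print(y.shape)  # (4, 7, 26, 26, 2)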