[TF] Add support for higher rank batch shape to keras.layers.Conv{,2D}.

The support is added to the base Conv layer, but the underlying 1D and 3D NN ops
do not yet support it. That is now straightforward to add: if this change goes
in, support for nn_ops conv1d and conv3d and for the Keras Conv1D and Conv3D
layers can follow together.

PiperOrigin-RevId: 314331499
Change-Id: Id1f723993881c206d006711afd3fc7921459df11
commit 4b123fc001
parent f1dd2e7a45
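As a quick usage sketch of the new behavior (assuming this change is applied), a
Conv2D layer now accepts inputs whose batch shape has rank greater than one; the
extra leading dimensions pass through unchanged:

    import tensorflow as tf

    # Batch shape (4, 7): two leading batch dimensions, then 28x28 RGB images.
    x = tf.random.normal((4, 7, 28, 28, 3))
    # A 3x3 'valid' convolution with 2 filters shrinks 28 -> 26 spatially.
    y = tf.keras.layers.Conv2D(2, 3, activation='relu')(x)
    print(y.shape)  # (4, 7, 26, 26, 2)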
@@ -162,7 +162,7 @@ class Conv(Layer):
     self.bias_regularizer = regularizers.get(bias_regularizer)
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)
-    self.input_spec = InputSpec(ndim=self.rank + 2)
+    self.input_spec = InputSpec(min_ndim=self.rank + 2)
 
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
@@ -196,7 +196,7 @@ class Conv(Layer):
     else:
       self.bias = None
     channel_axis = self._get_channel_axis()
-    self.input_spec = InputSpec(ndim=self.rank + 2,
+    self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                 axes={channel_axis: input_channel})
 
     self._build_conv_op_input_shape = input_shape
@@ -210,18 +210,20 @@ class Conv(Layer):
         dilation_rate=self.dilation_rate,
         strides=self.strides,
         padding=self._padding_op,
-        data_format=self._conv_op_data_format)
+        data_format=self._conv_op_data_format,
+        num_spatial_dims=self.rank)
     self.built = True
 
   def call(self, inputs):
     if self._recreate_conv_op(inputs):
       self._convolution_op = nn_ops.Convolution(
-          inputs.get_shape(),
+          inputs.shape,
           filter_shape=self.kernel.shape,
           dilation_rate=self.dilation_rate,
           strides=self.strides,
           padding=self._padding_op,
-          data_format=self._conv_op_data_format)
+          data_format=self._conv_op_data_format,
+          num_spatial_dims=self.rank)
       self._build_conv_op_input_shape = inputs.get_shape()
 
     # Apply causal padding to inputs for Conv1D.
@@ -231,15 +233,30 @@ class Conv(Layer):
     outputs = self._convolution_op(inputs, self.kernel)
 
     if self.use_bias:
+      outputs_rank = outputs.shape.ndims
       if self.data_format == 'channels_first':
         if self.rank == 1:
           # nn.bias_add does not accept a 1D input tensor.
           bias = array_ops.reshape(self.bias, (1, self.filters, 1))
           outputs += bias
         else:
-          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
+          if outputs_rank is not None and outputs_rank > 2 + self.rank:
+            # larger batch rank
+            outputs = nn_ops.squeeze_batch_dims(
+                outputs,
+                lambda o: nn.bias_add(o, self.bias, data_format='NCHW'),
+                inner_rank=self.rank + 1)
+          else:
+            outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
       else:
-        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
+        if outputs_rank is not None and outputs_rank > 2 + self.rank:
+          # larger batch rank
+          outputs = nn_ops.squeeze_batch_dims(
+              outputs,
+              lambda o: nn.bias_add(o, self.bias, data_format='NHWC'),
+              inner_rank=self.rank + 1)
+        else:
+          outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
 
     if self.activation is not None:
       return self.activation(outputs)
@@ -258,14 +275,16 @@ class Conv(Layer):
 
   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    batch_rank = len(input_shape) - self.rank - 1
     if self.data_format == 'channels_last':
       return tensor_shape.TensorShape(
-          [input_shape[0]] + self._spatial_output_shape(input_shape[1:-1]) +
-          [self.filters])
+          input_shape[:batch_rank]
+          + self._spatial_output_shape(input_shape[batch_rank:-1])
+          + [self.filters])
     else:
       return tensor_shape.TensorShape(
-          [input_shape[0], self.filters] +
-          self._spatial_output_shape(input_shape[2:]))
+          input_shape[:batch_rank] + [self.filters] +
+          self._spatial_output_shape(input_shape[batch_rank + 1:]))
 
   def get_config(self):
     config = {
@@ -316,7 +335,7 @@ class Conv(Layer):
 
   def _get_channel_axis(self):
     if self.data_format == 'channels_first':
-      return 1
+      return -1 - self.rank
     else:
       return -1
 
@@ -350,7 +369,7 @@ class Conv(Layer):
     Returns:
       `True` or `False` to indicate whether to recreate the conv_op.
     """
-    call_input_shape = inputs.get_shape()
+    call_input_shape = inputs.shape
     # If the most specific compatible shape between _build_input_shape and
     # call_input_shape is not _build_input_shape then we must re-build.
     return self._build_conv_op_input_shape.most_specific_compatible_shape(
@@ -381,7 +400,7 @@ class Conv1D(Conv):
   >>> input_shape = (4, 10, 128)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv1D(
-  ... 32, 3, activation='relu',input_shape=input_shape)(x)
+  ... 32, 3, activation='relu',input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 8, 32)
 
@@ -508,7 +527,7 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 26, 26, 2)
 
@@ -516,7 +535,7 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', dilation_rate=2, input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 24, 24, 2)
 
@@ -524,10 +543,18 @@ class Conv2D(Conv):
   >>> input_shape = (4, 28, 28, 3)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv2D(
-  ... 2, 3, activation='relu', padding="same", input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', padding="same", input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 28, 28, 2)
 
+  >>> # With extended batch shape [4, 7]:
+  >>> input_shape = (4, 7, 28, 28, 3)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv2D(
+  ... 2, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 26, 26, 2)
+
 
   Arguments:
     filters: Integer, the dimensionality of the output space
@@ -552,7 +579,7 @@ class Conv2D(Conv):
       `(batch_size, channels, height, width)`.
       It defaults to the `image_data_format` value found in your
       Keras config file at `~/.keras/keras.json`.
-      If you never set it, then it will be "channels_last".
+      If you never set it, then it will be `channels_last`.
     dilation_rate: an integer or tuple/list of 2 integers, specifying
       the dilation rate to use for dilated convolution.
       Can be a single integer to specify the same value for
@@ -585,25 +612,27 @@ class Conv2D(Conv):
       see `keras.constraints`).
 
   Input shape:
-    4D tensor with shape:
-    `(batch_size, channels, rows, cols)` if data_format='channels_first'
-    or 4D tensor with shape:
-    `(batch_size, rows, cols, channels)` if data_format='channels_last'.
+    4+D tensor with shape:
+    `batch_shape + (channels, rows, cols)` if `data_format='channels_first'`
+    or 4+D tensor with shape:
+    `batch_shape + (rows, cols, channels)` if `data_format='channels_last'`.
 
   Output shape:
-    4D tensor with shape:
-    `(batch_size, filters, new_rows, new_cols)` if data_format='channels_first'
-    or 4D tensor with shape:
-    `(batch_size, new_rows, new_cols, filters)` if data_format='channels_last'.
+    4+D tensor with shape:
+    `batch_shape + (filters, new_rows, new_cols)` if
+    `data_format='channels_first'` or 4+D tensor with shape:
+    `batch_shape + (new_rows, new_cols, filters)` if
+    `data_format='channels_last'`.
+
     `rows` and `cols` values might have changed due to padding.
 
   Returns:
-    A tensor of rank 4 representing
+    A tensor of rank 4+ representing
     `activation(conv2d(inputs, kernel) + bias)`.
 
   Raises:
-    ValueError: if `padding` is "causal".
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: if `padding` is `"causal"`.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """
 
   def __init__(self,
@@ -669,7 +698,7 @@ class Conv3D(Conv):
   >>> input_shape =(4, 28, 28, 28, 1)
   >>> x = tf.random.normal(input_shape)
   >>> y = tf.keras.layers.Conv3D(
-  ... 2, 3, activation='relu', input_shape=input_shape)(x)
+  ... 2, 3, activation='relu', input_shape=input_shape[1:])(x)
   >>> print(y.shape)
   (4, 26, 26, 26, 2)
 
@@ -1198,8 +1227,8 @@ class Conv2DTranspose(Conv2D):
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if len(input_shape) != 4:
-      raise ValueError('Inputs should have rank 4. Received input shape: ' +
-                       str(input_shape))
+      raise ValueError('Inputs should have rank 4. Received input '
+                       'shape: ' + str(input_shape))
     channel_axis = self._get_channel_axis()
     if input_shape.dims[channel_axis].value is None:
       raise ValueError('The channel dimension of the inputs '
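Between the layer changes above and the test changes below, a brief sketch of the
shape arithmetic that the new compute_output_shape implements (illustrative
values, assuming the change is applied):

    import tensorflow as tf

    # batch_rank = len(input_shape) - rank - 1 = 5 - 2 - 1 = 2, so the two
    # leading dims (4, 7) pass through; a 3x3 'valid' conv maps 28 -> 26.
    layer = tf.keras.layers.Conv2D(32, 3)
    print(layer.compute_output_shape((4, 7, 28, 28, 3)))
    # Expected: (4, 7, 26, 26, 32)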
@@ -171,6 +171,21 @@ class Conv2DTest(keras_parameterized.TestCase):
           input_shape=(num_samples, num_row, num_col, stack_size),
           expected_output_shape=expected_output_shape)
 
+  def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
+    batch_shape = (2, 11)
+    stack_size = 3
+    num_row = 7
+    num_col = 6
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+      testing_utils.layer_test(
+          keras.layers.Conv2D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (num_row, num_col, stack_size),
+          expected_output_shape=expected_output_shape)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -205,6 +220,7 @@ class Conv2DTest(keras_parameterized.TestCase):
     kwargs['kernel_size'] = (3, 3)
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape)
 
   def test_conv2d_regularizers(self):
     kwargs = {
@@ -254,7 +254,7 @@ class _NonAtrousConvolution(object):
           name=self.name)
 
 
-def _squeeze_batch_dims(inp, op, inner_rank, name):
+def squeeze_batch_dims(inp, op, inner_rank, name=None):
   """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.
 
   Where `squeeze_batch` reshapes `inp` to shape
@@ -272,7 +272,7 @@ def _squeeze_batch_dims(inp, op, inner_rank, name):
   Returns:
     `unsqueeze_batch_op(squeeze_batch(inp))`.
   """
-  with ops.name_scope(name, "Convolution", [inp]):
+  with ops.name_scope(name, "squeeze_batch_dims", [inp]):
     inp = ops.convert_to_tensor(inp, name="input")
     shape = inp.shape
 
@@ -2224,7 +2224,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
         data_format=data_format,
         dilations=dilations,
         name=name)
-  return _squeeze_batch_dims(
+  return squeeze_batch_dims(
       input,
       functools.partial(
           gen_nn_ops.conv2d,
@@ -2543,7 +2543,7 @@ def _conv2d_expanded_batch(
         data_format=data_format,
         dilations=dilations,
         name=name)
-  return _squeeze_batch_dims(
+  return squeeze_batch_dims(
      input,
      functools.partial(
          gen_nn_ops.conv2d,
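For reference, the squeeze/unsqueeze pattern that squeeze_batch_dims applies can
be sketched as follows. This is an illustrative reimplementation in plain
TensorFlow, not the library code; the name squeeze_batch_dims_sketch is
hypothetical:

    import tensorflow as tf

    def squeeze_batch_dims_sketch(inp, op, inner_rank):
      """Collapse leading batch dims, apply `op`, then restore the batch shape."""
      shape = tf.shape(inp)
      inner_shape = shape[-inner_rank:]   # the dims the wrapped op understands
      batch_shape = shape[:-inner_rank]   # arbitrary-rank leading batch dims
      # Flatten all batch dims into one so `op` sees a rank (inner_rank + 1) tensor.
      squeezed = tf.reshape(inp, tf.concat([[-1], inner_shape], axis=0))
      out = op(squeezed)
      # Re-expand the flattened batch dim back to the original batch shape.
      return tf.reshape(out, tf.concat([batch_shape, tf.shape(out)[1:]], axis=0))

    # Example: bias_add on an NHWC conv output with batch shape (2, 5).
    x = tf.random.normal((2, 5, 8, 8, 3))
    bias = tf.constant([0.1, 0.2, 0.3])
    y = squeeze_batch_dims_sketch(x, lambda o: tf.nn.bias_add(o, bias), inner_rank=3)
    print(y.shape)  # (2, 5, 8, 8, 3)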