[TF] Add extra batch dim support to tf.nn.conv{1,3}d and keras.layers.Conv{1,3}D.

This extends the support just added for tf.nn.conv2d and keras.layers.Conv2D.

PiperOrigin-RevId: 314738423
Change-Id: I588c36e493d1f41a67b9721e2cbd84c564277f44
Author: Eugene Brevdo
Date: 2020-06-04 08:59:35 -07:00
Committed by: TensorFlower Gardener
Parent: 47f7695c00
Commit: 549e69ca13
Stats: 5 changed files with 287 additions and 87 deletions
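In short: inputs to these ops may now carry any number of leading batch dimensions; everything but the innermost spatial-and-channel dimensions is treated as batch. A quick sketch of the behavior this change enables (shape values here are illustrative; the Conv3D case mirrors the doctest added below):

import tensorflow as tf

# tf.nn.conv1d now accepts rank >= 3: leading dims beyond the last two act as batch.
x = tf.random.normal([2, 5, 10, 8])      # batch_shape=[2, 5], in_width=10, in_channels=8
filters = tf.random.normal([3, 8, 16])   # [filter_width, in_channels, out_channels]
y = tf.nn.conv1d(x, filters, stride=1, padding="SAME")
print(y.shape)                           # (2, 5, 10, 16)

# keras.layers.Conv3D likewise accepts an extended batch shape.
v = tf.random.normal([4, 7, 28, 28, 28, 1])   # batch_shape=[4, 7]
print(tf.keras.layers.Conv3D(2, 3, activation='relu')(v).shape)  # (4, 7, 26, 26, 26, 2)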


@@ -228,7 +228,7 @@ class Conv(Layer):
     # Apply causal padding to inputs for Conv1D.
     if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
-      inputs = array_ops.pad(inputs, self._compute_causal_padding())
+      inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
     outputs = self._convolution_op(inputs, self.kernel)
@@ -324,13 +324,17 @@ class Conv(Layer):
     base_config = super(Conv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def _compute_causal_padding(self):
+  def _compute_causal_padding(self, inputs):
     """Calculates padding for 'causal' option for 1-d conv layers."""
     left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
-    if self.data_format == 'channels_last':
-      causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
-    else:
-      causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+    if getattr(inputs.shape, 'ndims', None) is None:
+      batch_rank = 1
+    else:
+      batch_rank = len(inputs.shape) - 2
+    if self.data_format == 'channels_last':
+      causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
+    else:
+      causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
     return causal_padding

   def _get_channel_axis(self):
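The key change above is that causal padding is now built relative to the input's batch rank instead of assuming a single batch dimension. A standalone sketch of that computation, assuming channels_last data, kernel_size=4, dilation_rate=2, and two batch dimensions:

import tensorflow as tf

inputs = tf.zeros([2, 5, 10, 8])    # batch_shape + (steps, channels); batch_rank = 4 - 2 = 2
kernel_size, dilation_rate = 4, 2
left_pad = dilation_rate * (kernel_size - 1)   # 6 zeros prepended along the time axis
batch_rank = len(inputs.shape) - 2
causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
print(tf.pad(inputs, causal_padding).shape)    # (2, 5, 16, 8)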
@@ -404,6 +408,16 @@ class Conv1D(Conv):
   >>> print(y.shape)
   (4, 8, 32)

+  >>> # With extended batch shape [4, 7] (e.g. weather data where batch
+  >>> # dimensions correspond to spatial location and the third dimension
+  >>> # corresponds to time):
+  >>> input_shape = (4, 7, 10, 128)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv1D(
+  ...     32, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 8, 32)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -451,10 +465,10 @@ class Conv1D(Conv):
       see `keras.constraints`).

   Input shape:
-    3D tensor with shape: `(batch_size, steps, input_dim)`
+    3+D tensor with shape: `batch_shape + (steps, input_dim)`

   Output shape:
-    3D tensor with shape: `(batch_size, new_steps, filters)`
+    3+D tensor with shape: `batch_shape + (new_steps, filters)`
       `steps` value might have changed due to padding or strides.

   Returns:
@@ -462,7 +476,7 @@ class Conv1D(Conv):
     `activation(conv1d(inputs, kernel) + bias)`.

   Raises:
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """

   def __init__(self,
@@ -702,6 +716,15 @@ class Conv3D(Conv):
   >>> print(y.shape)
   (4, 26, 26, 26, 2)

+  >>> # With extended batch shape [4, 7], e.g. a batch of 4 videos of 3D frames,
+  >>> # with 7 frames per video:
+  >>> input_shape = (4, 7, 28, 28, 28, 1)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv3D(
+  ...     2, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 26, 26, 26, 2)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -721,9 +744,9 @@ class Conv3D(Conv):
       one of `channels_last` (default) or `channels_first`.
       The ordering of the dimensions in the inputs.
       `channels_last` corresponds to inputs with shape
-      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
+      `batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
       while `channels_first` corresponds to inputs with shape
-      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
+      `batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
       It defaults to the `image_data_format` value found in your
       Keras config file at `~/.keras/keras.json`.
       If you never set it, then it will be "channels_last".
@@ -760,30 +783,30 @@ class Conv3D(Conv):
       see `keras.constraints`).

   Input shape:
-    5D tensor with shape:
-    `(batch_size, channels, conv_dim1, conv_dim2, conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
       data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, conv_dim1, conv_dim2, conv_dim3, channels)` if
+    or 5+D tensor with shape:
+    `batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
       data_format='channels_last'.

   Output shape:
-    5D tensor with shape:
-    `(batch_size, filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
      data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
+    or 5+D tensor with shape:
+    `batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
       data_format='channels_last'.
     `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
       changed due to padding.

   Returns:
-    A tensor of rank 5 representing
+    A tensor of rank 5+ representing
     `activation(conv3d(inputs, kernel) + bias)`.

   Raises:
     ValueError: if `padding` is "causal".
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """

   def __init__(self,


@@ -49,6 +49,21 @@ class Conv1DTest(keras_parameterized.TestCase):
           input_shape=(num_samples, length, stack_size),
           expected_output_shape=expected_output_shape)

+  def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
+    batch_shape = (2, 11)
+    stack_size = 3
+    length = 7
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+      testing_utils.layer_test(
+          keras.layers.Conv1D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (length, stack_size),
+          expected_output_shape=expected_output_shape)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -85,6 +100,7 @@ class Conv1DTest(keras_parameterized.TestCase):
     kwargs['kernel_size'] = 3
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape)

   def test_conv1d_regularizers(self):
     kwargs = {
@@ -281,6 +297,27 @@ class Conv3DTest(keras_parameterized.TestCase):
           expected_output_shape=expected_output_shape,
           validate_training=validate_training)

+  def _run_test_extra_batch_dim(self,
+                                kwargs,
+                                expected_output_shape,
+                                validate_training=True):
+    batch_shape = (2, 11)
+    stack_size = 3
+    num_row = 7
+    num_col = 6
+    depth = 5
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+      testing_utils.layer_test(
+          keras.layers.Conv3D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (depth, num_row, num_col, stack_size),
+          expected_output_shape=expected_output_shape,
+          validate_training=validate_training)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -313,6 +350,8 @@ class Conv3DTest(keras_parameterized.TestCase):
     test_training = 'groups' not in kwargs or not test_util.is_xla_enabled()
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape, test_training)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape,
+                                     test_training)

   def test_conv3d_regularizers(self):
     kwargs = {


@@ -55,6 +55,30 @@ class Conv1DTest(test.TestCase):
           self.assertEqual(len(output), 2)
           self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])

+  def testExpandedBatch(self):
+    """Tests that conv1d handles inputs with an expanded batch shape."""
+    x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32)
+    x = array_ops.expand_dims(x, 0)  # Add batch dimension
+    x = array_ops.expand_dims(x, 2)  # And depth dimension
+    x = array_ops.stack([x, x])  # Make batch shape [2, 1]
+    filters = constant_op.constant([2, 1], dtype=dtypes.float32)
+    filters = array_ops.expand_dims(filters, 1)  # in_channels
+    filters = array_ops.expand_dims(filters, 2)  # out_channels
+    # Filters is 2x1x1
+    for stride in [1, 2]:
+      with self.cached_session(use_gpu=test.is_gpu_available()):
+        c = nn_ops.conv1d(x, filters, stride, padding="VALID")
+        reduced = array_ops.squeeze(c)  # Squeeze out dims 1 and 3.
+        output = self.evaluate(reduced)
+        if stride == 1:
+          self.assertAllClose(output,
+                              [[2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4],
+                               [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]])
+        else:
+          self.assertAllClose(
+              output,
+              [[2 * 1 + 1 * 2, 2 * 3 + 1 * 4], [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]])
+
   def testConv1DTranspose(self):
     with self.cached_session():
       stride = 2
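As a cross-check of the expected values in testExpandedBatch above: conv1d computes a cross-correlation, so with input [1, 2, 3, 4] and filter [2, 1], every entry of the [2, 1] batch should see [2*1 + 1*2, 2*2 + 1*3, 2*3 + 1*4] = [4, 7, 10] at stride 1. A one-liner reproducing that reference result in NumPy:

import numpy as np

x = np.array([1., 2., 3., 4.])
w = np.array([2., 1.])
# np.convolve flips its kernel, so flip once to recover cross-correlation.
print(np.convolve(x, w[::-1], mode="valid"))   # [ 4.  7. 10.]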


@@ -66,12 +66,8 @@ class Conv3DTest(test.TestCase):
   def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
                             padding, data_format, dtype, use_gpu):
-    total_size_tensor = 1
-    total_size_filter = 1
-    for s in tensor_in_sizes:
-      total_size_tensor *= s
-    for s in filter_in_sizes:
-      total_size_filter *= s
+    total_size_tensor = np.prod(tensor_in_sizes)
+    total_size_filter = np.prod(filter_in_sizes)

     # Initializes the input tensor with array containing numbers from 0 to 1.
     # We keep the input tensor values fairly small to avoid overflowing float16
@@ -126,12 +122,8 @@ class Conv3DTest(test.TestCase):
   def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes,
                                    stride, dilation, padding, data_format,
                                    use_gpu):
-    total_size_tensor = 1
-    total_size_filter = 1
-    for s in tensor_in_sizes:
-      total_size_tensor *= s
-    for s in filter_in_sizes:
-      total_size_filter *= s
+    total_size_tensor = np.prod(tensor_in_sizes)
+    total_size_filter = np.prod(filter_in_sizes)

     # Initializes the input tensor with array containing incrementing
     # numbers from 1.
@@ -196,6 +188,69 @@ class Conv3DTest(test.TestCase):
       self.assertAllClose(
           e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)

+  def _CreateNumpyTensor(self, sizes):
+    return np.asarray(
+        [f * 1.0 for f in range(1, np.prod(sizes) + 1)]).reshape(sizes)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConv3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.conv3d_v2(
+        x1, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    conv2 = nn_ops.conv3d_v2(
+        x2, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1, filter_in, strides=[1, 1, 1], padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2, filter_in, strides=[1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
   def testConv3D1x1x1Filter(self):
     expected_output = [
         0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
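All three new tests assert the same invariant: a convolution over batch shape [2, 5] must match the convolution over the batch flattened to [10], reshaped back. A condensed, hedged sketch of that property through the public API (eager mode assumed):

import tensorflow as tf

x1 = tf.reshape(tf.range(180, dtype=tf.float32), [10, 2, 3, 1, 3])
x2 = tf.reshape(x1, [2, 5, 2, 3, 1, 3])     # same data, expanded batch
filt = tf.ones([1, 1, 1, 3, 3])             # 1x1x1 kernel, 3 -> 3 channels
c1 = tf.nn.conv3d(x1, filt, strides=[1] * 5, padding="VALID")
c2 = tf.nn.conv3d(x2, filt, strides=[1] * 5, padding="VALID")
tf.debugging.assert_near(c1, tf.reshape(c2, c1.shape))  # passes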


@@ -228,7 +228,7 @@ class _NonAtrousConvolution(object):
                        % data_format)
       self.strides = strides
       self.data_format = data_format
-      self.conv_op = gen_nn_ops.conv3d
+      self.conv_op = _conv3d_expanded_batch

   # Note that we need this adapter since argument names for conv1d don't match
   # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
@@ -241,7 +241,6 @@ class _NonAtrousConvolution(object):
         padding=padding,
         data_format=data_format,
         name=name)
-
   # pylint: enable=redefined-builtin

   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
@@ -1104,7 +1103,7 @@ def convolution_internal(
     scope = "convolution"

   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}

     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1783,40 +1782,42 @@ def conv1d(
     name=None,
     input=None,  # pylint: disable=redefined-builtin
     dilations=None):
-  r"""Computes a 1-D convolution given 3-D input and filter tensors.
+  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.

   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.

   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
   For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.

   Args:
-    value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`. Must have the same type as `value`.
+    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`,
+      or `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `value`.
     stride: An int or list of `ints` that has length `1` or `3`. The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
     data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels]. The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of `batch_shape + [in_width,
+      in_channels]`. The `"NCW"` format stores data as `batch_shape +
+      [in_channels, in_width]`.
     name: A name for the operation (optional).
     input: Alias for value.
     dilations: An int or list of `ints` that has length `1` or `3` which
@@ -1832,14 +1833,14 @@ def conv1d(
   """
   value = deprecation.deprecated_argument_lookup("input", input, "value", value)
   with ops.name_scope(name, "conv1d", [value, filters]) as name:
-    # Reshape the input tensor to [batch, 1, in_width, in_channels]
+    # Reshape the input tensor to batch_shape + [1, in_width, in_channels]
     if data_format is None or data_format == "NHWC" or data_format == "NWC":
       data_format = "NHWC"
-      spatial_start_dim = 1
+      spatial_start_dim = -3
       channel_index = 2
     elif data_format == "NCHW" or data_format == "NCW":
       data_format = "NCHW"
-      spatial_start_dim = 2
+      spatial_start_dim = -2
       channel_index = 1
     else:
       raise ValueError("data_format must be \"NWC\" or \"NCW\".")
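The switch from positive to negative spatial_start_dim is what makes the conv2d trick batch-rank agnostic: expand_dims(value, -3) inserts the dummy height axis just before [in_width, in_channels] no matter how many batch dimensions precede them. Illustrative shapes:

import tensorflow as tf

v3 = tf.zeros([8, 10, 4])                 # [batch, in_width, in_channels]
v5 = tf.zeros([2, 3, 8, 10, 4])           # batch_shape=[2, 3, 8]
print(tf.expand_dims(v3, -3).shape)       # (8, 1, 10, 4)
print(tf.expand_dims(v5, -3).shape)       # (2, 3, 8, 1, 10, 4)
# The old spatial_start_dim=1 would have produced (2, 1, 3, 8, 10, 4) for the
# 5-D case, inserting the dummy axis among the batch dimensions.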
@@ -1848,6 +1849,7 @@ def conv1d(
     value = array_ops.expand_dims(value, spatial_start_dim)
     filters = array_ops.expand_dims(filters, 0)

+    if value.shape.ndims in (4, 3, 2, 1, 0, None):
       result = gen_nn_ops.conv2d(
           value,
           filters,
@@ -1857,6 +1859,20 @@ def conv1d(
           data_format=data_format,
           dilations=dilations,
           name=name)
+    else:
+      result = squeeze_batch_dims(
+          value,
+          functools.partial(
+              gen_nn_ops.conv2d,
+              filter=filters,
+              strides=strides,
+              padding=padding,
+              use_cudnn_on_gpu=use_cudnn_on_gpu,
+              data_format=data_format,
+              dilations=dilations,
+          ),
+          inner_rank=3,
+          name=name)
     return array_ops.squeeze(result, [spatial_start_dim])
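squeeze_batch_dims (introduced with the conv2d change this commit extends) folds all outer batch dimensions into one, applies the wrapped op, then restores them. A simplified, hedged re-implementation of the idea for statically known shapes; _squeeze_batch_dims_sketch is illustrative, not the library helper:

import functools
import tensorflow as tf

def _squeeze_batch_dims_sketch(inp, op, inner_rank):
  """Collapses leading batch dims of `inp`, applies `op`, restores them."""
  shape = inp.shape.as_list()
  batch_shape, inner_shape = shape[:-inner_rank], shape[-inner_rank:]
  flat = tf.reshape(inp, [-1] + inner_shape)   # merge batch dims into one
  out = op(flat)
  return tf.reshape(out, batch_shape + out.shape.as_list()[1:])

x = tf.random.normal([2, 5, 4, 10, 8])         # batch_shape=[2, 5]
conv = functools.partial(
    tf.nn.conv2d, filters=tf.random.normal([3, 3, 8, 16]),
    strides=1, padding="SAME")
print(_squeeze_batch_dims_sketch(x, conv, inner_rank=3).shape)  # (2, 5, 4, 10, 16)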
@@ -1873,36 +1889,38 @@ def conv1d_v2(
   r"""Computes a 1-D convolution given 3-D input and filter tensors.

   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.

   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
-  For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+  For example, if `data_format` does not start with `"NC"`, a tensor of shape
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.

   Args:
-    input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`. Must have the same type as `input`.
+    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`,
+      or `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `input`.
     stride: An int or list of `ints` that has length `1` or `3`. The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels]. The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of
+      `batch_shape + [in_width, in_channels]`. The `"NCW"` format stores data
+      as `batch_shape + [in_channels, in_width]`.
     dilations: An int or list of `ints` that has length `1` or `3` which
       defaults to 1. The dilation factor for each dimension of input. If set to
       k > 1, there will be k-1 skipped cells between each filter element on that
@@ -2071,9 +2089,9 @@ def conv2d_v2(input,  # pylint: disable=redefined-builtin
   Args:
     input: A `Tensor`. Must be one of the following types:
       `half`, `bfloat16`, `float32`, `float64`.
-      A 4+-D tensor. The dimension order is interpreted according to the value
-      of `data_format`; with the all-but-inner-3 dimensions acting as batch
-      dimensions. See below for details.
+      A Tensor of rank at least 4. The dimension order is interpreted according
+      to the value of `data_format`; with the all-but-inner-3 dimensions acting
+      as batch dimensions. See below for details.
     filters: A `Tensor`. Must have the same type as `input`.
       A 4-D tensor of shape
       `[filter_height, filter_width, in_channels, out_channels]`
@@ -2214,6 +2232,8 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode. This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filter,
@@ -2535,6 +2555,8 @@ def _conv2d_expanded_batch(
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode. This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filters,
@@ -2931,6 +2953,47 @@ def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builtin
       name=name)


+def _conv3d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filter,  # pylint: disable=redefined-builtin
+    strides,
+    padding,
+    data_format,
+    dilations=None,
+    name=None):
+  """Helper function for `conv3d`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  ndims = None
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1:
+      ndims = len(shape)
+  if ndims in (5, 4, 3, 2, 1, 0, None):
+    # We avoid calling squeeze_batch_dims to reduce extra python function
+    # call slowdown in eager mode. This branch doesn't require reshapes.
+    return gen_nn_ops.conv3d(
+        input,
+        filter,
+        strides,
+        padding,
+        data_format=data_format,
+        dilations=dilations,
+        name=name)
+  else:
+    return squeeze_batch_dims(
+        input,
+        functools.partial(
+            gen_nn_ops.conv3d,
+            filter=filter,
+            strides=strides,
+            padding=padding,
+            data_format=data_format,
+            dilations=dilations),
+        inner_rank=4,
+        name=name)
+
+
 @tf_export("nn.conv3d", v1=[])
 @dispatch.add_dispatch_support
 def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
@@ -2942,13 +3004,8 @@ def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
               name=None):
   if dilations is None:
     dilations = [1, 1, 1, 1, 1]
-  return gen_nn_ops.conv3d(input,
-                           filters,
-                           strides,
-                           padding,
-                           data_format=data_format,
-                           dilations=dilations,
-                           name=name)
+  return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
+                                dilations, name)


 @tf_export(v1=["nn.conv3d"])