[TF] Add extra batch dim support to tf.nn.conv{1,3}d and keras.layers.Conv{1,3}D.

This extends the support just added for tf.nn.conv2d and keras.layers.Conv2D.

PiperOrigin-RevId: 314738423
Change-Id: I588c36e493d1f41a67b9721e2cbd84c564277f44
Eugene Brevdo 2020-06-04 08:59:35 -07:00 committed by TensorFlower Gardener
parent 47f7695c00
commit 549e69ca13
5 changed files with 287 additions and 87 deletions
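
As a quick illustration of the behavior this commit enables (a hedged sketch, not part of the diff; assumes a TensorFlow build that includes this change), keras.layers.Conv1D now accepts inputs with more than one leading batch dimension:

import tensorflow as tf

# batch_shape = [2, 5], width = 8, channels = 3.
x = tf.random.normal([2, 5, 8, 3])
y = tf.keras.layers.Conv1D(4, 3, padding='valid')(x)
print(y.shape)  # (2, 5, 6, 4): both leading batch dims are preserved.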


@@ -228,7 +228,7 @@ class Conv(Layer):
# Apply causal padding to inputs for Conv1D.
if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
inputs = array_ops.pad(inputs, self._compute_causal_padding())
inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
outputs = self._convolution_op(inputs, self.kernel)
@@ -324,13 +324,17 @@ class Conv(Layer):
base_config = super(Conv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _compute_causal_padding(self):
def _compute_causal_padding(self, inputs):
"""Calculates padding for 'causal' option for 1-d conv layers."""
left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
if self.data_format == 'channels_last':
causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
if getattr(inputs.shape, 'ndims', None) is None:
batch_rank = 1
else:
causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
batch_rank = len(inputs.shape) - 2
if self.data_format == 'channels_last':
causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
else:
causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
return causal_padding
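
For intuition, here is an illustrative sketch (not part of the diff) of the padding spec the new _compute_causal_padding produces for a channels_last input with two batch dims; only the steps axis gets left-padded:

import tensorflow as tf

kernel_size, dilation_rate = 3, 1
left_pad = dilation_rate * (kernel_size - 1)  # 2
inputs = tf.zeros([4, 7, 10, 128])  # batch_shape = [4, 7]
batch_rank = len(inputs.shape) - 2  # 2
causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
print(tf.pad(inputs, causal_padding).shape)  # (4, 7, 12, 128)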
def _get_channel_axis(self):
@@ -404,6 +408,16 @@ class Conv1D(Conv):
>>> print(y.shape)
(4, 8, 32)
>>> # With extended batch shape [4, 7] (e.g. weather data where batch
>>> # dimensions correspond to spatial location and the third dimension
>>> # corresponds to time).
>>> input_shape = (4, 7, 10, 128)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv1D(
... 32, 3, activation='relu', input_shape=input_shape[2:])(x)
>>> print(y.shape)
(4, 7, 8, 32)
Arguments:
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
@@ -451,10 +465,10 @@ class Conv1D(Conv):
see `keras.constraints`).
Input shape:
3D tensor with shape: `(batch_size, steps, input_dim)`
3+D tensor with shape: `batch_shape + (steps, input_dim)`
Output shape:
3D tensor with shape: `(batch_size, new_steps, filters)`
3+D tensor with shape: `batch_shape + (new_steps, filters)`
`steps` value might have changed due to padding or strides.
Returns:
@@ -462,7 +476,7 @@ class Conv1D(Conv):
`activation(conv1d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides` > 1 and `dilation_rate` > 1.
ValueError: when both `strides > 1` and `dilation_rate > 1`.
"""
def __init__(self,
@@ -702,6 +716,15 @@ class Conv3D(Conv):
>>> print(y.shape)
(4, 26, 26, 26, 2)
>>> # With extended batch shape [4, 7], e.g. a batch of 4 videos of 3D frames,
>>> # with 7 frames per video.
>>> input_shape = (4, 7, 28, 28, 28, 1)
>>> x = tf.random.normal(input_shape)
>>> y = tf.keras.layers.Conv3D(
... 2, 3, activation='relu', input_shape=input_shape[2:])(x)
>>> print(y.shape)
(4, 7, 26, 26, 26, 2)
Arguments:
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
@@ -721,9 +744,9 @@ class Conv3D(Conv):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
`batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `channels_first` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
`batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -760,30 +783,30 @@ class Conv3D(Conv):
see `keras.constraints`).
Input shape:
5D tensor with shape:
`(batch_size, channels, conv_dim1, conv_dim2, conv_dim3)` if
5+D tensor with shape:
`batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
data_format='channels_first'
or 5D tensor with shape:
`(batch_size, conv_dim1, conv_dim2, conv_dim3, channels)` if
or 5+D tensor with shape:
`batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
data_format='channels_last'.
Output shape:
5D tensor with shape:
`(batch_size, filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
5+D tensor with shape:
`batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
data_format='channels_first'
or 5D tensor with shape:
`(batch_size, new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
or 5+D tensor with shape:
`batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
data_format='channels_last'.
`new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
changed due to padding.
Returns:
A tensor of rank 5 representing
A tensor of rank 5+ representing
`activation(conv3d(inputs, kernel) + bias)`.
Raises:
ValueError: if `padding` is "causal".
ValueError: when both `strides` > 1 and `dilation_rate` > 1.
ValueError: when both `strides > 1` and `dilation_rate > 1`.
"""
def __init__(self,


@@ -49,6 +49,21 @@ class Conv1DTest(keras_parameterized.TestCase):
input_shape=(num_samples, length, stack_size),
expected_output_shape=expected_output_shape)
def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
batch_shape = (2, 11)
stack_size = 3
length = 7
with self.cached_session(use_gpu=True):
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
testing_utils.layer_test(
keras.layers.Conv1D,
kwargs=kwargs,
input_shape=batch_shape + (length, stack_size),
expected_output_shape=expected_output_shape)
@parameterized.named_parameters(
('padding_valid', {
'padding': 'valid'
@@ -85,6 +100,7 @@ class Conv1DTest(keras_parameterized.TestCase):
kwargs['kernel_size'] = 3
if not requires_gpu or test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, expected_output_shape)
self._run_test_extra_batch_dim(kwargs, expected_output_shape)
def test_conv1d_regularizers(self):
kwargs = {
@@ -281,6 +297,27 @@ class Conv3DTest(keras_parameterized.TestCase):
expected_output_shape=expected_output_shape,
validate_training=validate_training)
def _run_test_extra_batch_dim(self,
kwargs,
expected_output_shape,
validate_training=True):
batch_shape = (2, 11)
stack_size = 3
num_row = 7
num_col = 6
depth = 5
with self.cached_session(use_gpu=True):
if expected_output_shape is not None:
expected_output_shape = (None,) + expected_output_shape
testing_utils.layer_test(
keras.layers.Conv3D,
kwargs=kwargs,
input_shape=batch_shape + (depth, num_row, num_col, stack_size),
expected_output_shape=expected_output_shape,
validate_training=validate_training)
@parameterized.named_parameters(
('padding_valid', {
'padding': 'valid'
@@ -313,6 +350,8 @@ class Conv3DTest(keras_parameterized.TestCase):
test_training = 'groups' not in kwargs or not test_util.is_xla_enabled()
if not requires_gpu or test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, expected_output_shape, test_training)
self._run_test_extra_batch_dim(kwargs, expected_output_shape,
test_training)
def test_conv3d_regularizers(self):
kwargs = {


@@ -55,6 +55,32 @@ class Conv1DTest(test.TestCase):
self.assertEqual(len(output), 2)
self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
def testExpandedBatch(self):
"""Test that argument passing to conv1d is handled properly."""
# double datatype is currently not supported for convolution ops
# on the ROCm platform
x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32)
x = array_ops.expand_dims(x, 0) # Add batch dimension
x = array_ops.expand_dims(x, 2) # And depth dimension
x = array_ops.stack([x, x]) # Make batch shape [2, 1]
filters = constant_op.constant([2, 1], dtype=dtypes.float32)
filters = array_ops.expand_dims(filters, 1) # in_channels
filters = array_ops.expand_dims(filters, 2) # out_channels
# Filters is 2x1x1
for stride in [1, 2]:
with self.cached_session(use_gpu=test.is_gpu_available()):
c = nn_ops.conv1d(x, filters, stride, padding="VALID")
reduced = array_ops.squeeze(c) # Squeeze out dims 1 and 3.
output = self.evaluate(reduced)
if stride == 1:
self.assertAllClose(output,
[[2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4],
[2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]])
else:
self.assertAllClose(
output,
[[2 * 1 + 1 * 2, 2 * 3 + 1 * 4], [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]])
def testConv1DTranspose(self):
with self.cached_session():
stride = 2


@@ -66,12 +66,8 @@ class Conv3DTest(test.TestCase):
def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
padding, data_format, dtype, use_gpu):
total_size_tensor = 1
total_size_filter = 1
for s in tensor_in_sizes:
total_size_tensor *= s
for s in filter_in_sizes:
total_size_filter *= s
total_size_tensor = np.prod(tensor_in_sizes)
total_size_filter = np.prod(filter_in_sizes)
# Initializes the input tensor with array containing numbers from 0 to 1.
# We keep the input tensor values fairly small to avoid overflowing float16
@@ -126,12 +122,8 @@ class Conv3DTest(test.TestCase):
def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes,
stride, dilation, padding, data_format,
use_gpu):
total_size_tensor = 1
total_size_filter = 1
for s in tensor_in_sizes:
total_size_tensor *= s
for s in filter_in_sizes:
total_size_filter *= s
total_size_tensor = np.prod(tensor_in_sizes)
total_size_filter = np.prod(filter_in_sizes)
# Initializes the input tensor with array containing incrementing
# numbers from 1.
@@ -196,6 +188,69 @@ class Conv3DTest(test.TestCase):
self.assertAllClose(
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)
def _CreateNumpyTensor(self, sizes):
return np.asarray([f * 1.0 for f in range(1, np.prod(sizes) + 1)]).reshape(sizes)
@test_util.run_in_graph_and_eager_modes
def testConv3DExpandedBatch(self):
tensor_in_sizes_batch = [10, 2, 3, 1, 3]
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
filter_in_sizes = [1, 1, 1, 3, 3]
filter_in = self._CreateNumpyTensor(filter_in_sizes)
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
conv1 = nn_ops.conv3d_v2(
x1, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
conv2 = nn_ops.conv3d_v2(
x2, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionClass3DExpandedBatch(self):
tensor_in_sizes_batch = [10, 2, 3, 1, 3]
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
filter_in_sizes = [1, 1, 1, 3, 3]
filter_in = self._CreateNumpyTensor(filter_in_sizes)
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
convolver1 = nn_ops.Convolution(
input_shape=x1.shape,
filter_shape=filter_in.shape,
strides=[1, 1, 1],
padding="VALID")
self.assertEqual(convolver1.num_batch_dims, 1)
convolver2 = nn_ops.Convolution(
input_shape=x2.shape,
filter_shape=filter_in.shape,
strides=[1, 1, 1],
padding="VALID")
self.assertEqual(convolver2.num_batch_dims, 2)
conv1 = convolver1(x1, filter_in)
conv2 = convolver2(x2, filter_in)
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
tensor_in_sizes_batch = [10, 2, 3, 1, 3]
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
filter_in_sizes = [1, 1, 1, 3, 3]
filter_in = self._CreateNumpyTensor(filter_in_sizes)
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
conv1 = nn_ops.convolution(
x1, filter_in, strides=[1, 1, 1], padding="VALID")
conv2 = nn_ops.convolution(
x2, filter_in, strides=[1, 1, 1], padding="VALID")
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
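
All three tests above assert the same invariant. Written out directly (an illustrative sketch, assuming eager mode and a TF build that includes this change): convolving with batch shape [2, 5] must equal convolving the same data with its batch dims flattened to 10, modulo a reshape.

import numpy as np
import tensorflow as tf

x = np.arange(1.0, 181.0, dtype=np.float32).reshape([2, 5, 2, 3, 1, 3])
filt = np.arange(1.0, 10.0, dtype=np.float32).reshape([1, 1, 1, 3, 3])

full = tf.nn.conv3d(x, filt, strides=[1, 1, 1, 1, 1], padding="VALID")
flat = tf.nn.conv3d(x.reshape([10, 2, 3, 1, 3]), filt,
                    strides=[1, 1, 1, 1, 1], padding="VALID")
np.testing.assert_allclose(full.numpy().reshape([10, 2, 3, 1, 3]), flat.numpy())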
def testConv3D1x1x1Filter(self):
expected_output = [
0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,


@@ -228,7 +228,7 @@ class _NonAtrousConvolution(object):
% data_format)
self.strides = strides
self.data_format = data_format
self.conv_op = gen_nn_ops.conv3d
self.conv_op = _conv3d_expanded_batch
# Note that we need this adapter since argument names for conv1d don't match
# those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
@@ -241,7 +241,6 @@ class _NonAtrousConvolution(object):
padding=padding,
data_format=data_format,
name=name)
# pylint: enable=redefined-builtin
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
@@ -1104,7 +1103,7 @@ def convolution_internal(
scope = "convolution"
with ops.name_scope(name, scope, [input, filters]) as name:
conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}
if device_context.enclosing_tpu_context() is not None or all(
i == 1 for i in dilations):
@@ -1783,40 +1782,42 @@ def conv1d(
name=None,
input=None, # pylint: disable=redefined-builtin
dilations=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors.
r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.
Given an input tensor of shape
[batch, in_width, in_channels]
if data_format is "NWC", or
[batch, in_channels, in_width]
if data_format is "NCW",
`batch_shape + [in_width, in_channels]`
if `data_format` is `"NWC"`, or
`batch_shape + [in_channels, in_width]`
if `data_format` is `"NCW"`,
and a filter / kernel tensor of shape
[filter_width, in_channels, out_channels], this op reshapes
the arguments to pass them to conv2d to perform the equivalent
`[filter_width, in_channels, out_channels]`, this op reshapes
the arguments to pass them to `conv2d` to perform the equivalent
convolution operation.
Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
For example, if `data_format` does not start with "NC", a tensor of shape
[batch, in_width, in_channels]
`batch_shape + [in_width, in_channels]`
is reshaped to
[batch, 1, in_width, in_channels],
`batch_shape + [1, in_width, in_channels]`,
and the filter is reshaped to
[1, filter_width, in_channels, out_channels].
`[1, filter_width, in_channels, out_channels]`.
The result is then reshaped back to
[batch, out_width, out_channels]
`batch_shape + [out_width, out_channels]`
\(where out_width is a function of the stride and padding as in conv2d\) and
returned to the caller.
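
Spelled out as code (an illustrative sketch, not part of the diff), the reshape equivalence this docstring describes is:

import tensorflow as tf

x = tf.random.normal([2, 9, 3])   # [batch, in_width, in_channels], "NWC"
w = tf.random.normal([4, 3, 5])   # [filter_width, in_channels, out_channels]

y1 = tf.nn.conv1d(x, w, stride=1, padding="VALID")
# Equivalent conv2d: add a height-1 spatial dim, convolve, squeeze it out.
y2 = tf.squeeze(
    tf.nn.conv2d(tf.expand_dims(x, 1), tf.expand_dims(w, 0),
                 strides=1, padding="VALID"),
    axis=1)
print(float(tf.reduce_max(tf.abs(y1 - y2))))  # ~0.0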
Args:
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `value`.
value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
`float64`.
filters: A Tensor of rank at least 3. Must have the same type as `value`.
stride: An int or list of `ints` that has length `1` or `3`. The number of
entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
the data is stored in the order of [batch, in_width, in_channels]. The
`"NCW"` format stores data as [batch, in_channels, in_width].
the data is stored in the order of `batch_shape + [in_width,
in_channels]`. The `"NCW"` format stores data as `batch_shape +
[in_channels, in_width]`.
name: A name for the operation (optional).
input: Alias for value.
dilations: An int or list of `ints` that has length `1` or `3` which
@@ -1832,14 +1833,14 @@ def conv1d(
"""
value = deprecation.deprecated_argument_lookup("input", input, "value", value)
with ops.name_scope(name, "conv1d", [value, filters]) as name:
# Reshape the input tensor to [batch, 1, in_width, in_channels]
# Reshape the input tensor to batch_shape + [1, in_width, in_channels]
if data_format is None or data_format == "NHWC" or data_format == "NWC":
data_format = "NHWC"
spatial_start_dim = 1
spatial_start_dim = -3
channel_index = 2
elif data_format == "NCHW" or data_format == "NCW":
data_format = "NCHW"
spatial_start_dim = 2
spatial_start_dim = -2
channel_index = 1
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
@@ -1848,15 +1849,30 @@ def conv1d(
value = array_ops.expand_dims(value, spatial_start_dim)
filters = array_ops.expand_dims(filters, 0)
result = gen_nn_ops.conv2d(
value,
filters,
strides,
padding,
use_cudnn_on_gpu=use_cudnn_on_gpu,
data_format=data_format,
dilations=dilations,
name=name)
if value.shape.ndims in (4, 3, 2, 1, 0, None):
result = gen_nn_ops.conv2d(
value,
filters,
strides,
padding,
use_cudnn_on_gpu=use_cudnn_on_gpu,
data_format=data_format,
dilations=dilations,
name=name)
else:
result = squeeze_batch_dims(
value,
functools.partial(
gen_nn_ops.conv2d,
filter=filters,
strides=strides,
padding=padding,
use_cudnn_on_gpu=use_cudnn_on_gpu,
data_format=data_format,
dilations=dilations,
),
inner_rank=3,
name=name)
return array_ops.squeeze(result, [spatial_start_dim])
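
The squeeze_batch_dims helper used in the else branch flattens all leading batch dims into one, applies the wrapped op, and restores them afterwards. A simplified conceptual sketch of that idea (an assumption-laden illustration, not the actual nn_ops.squeeze_batch_dims, which also propagates static shape information):

import tensorflow as tf

def _squeeze_batch_dims_sketch(inp, op, inner_rank):
  shape = tf.shape(inp)
  batch_shape = shape[:-inner_rank]   # the extra leading batch dims
  inner_shape = shape[-inner_rank:]   # the dims the underlying op understands
  flat = tf.reshape(inp, tf.concat([[-1], inner_shape], axis=0))
  out = op(flat)                      # op now sees a single batch dim
  out_inner_shape = tf.shape(out)[1:]
  return tf.reshape(out, tf.concat([batch_shape, out_inner_shape], axis=0))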
@@ -1873,36 +1889,38 @@ def conv1d_v2(
r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape
[batch, in_width, in_channels]
if data_format is "NWC", or
[batch, in_channels, in_width]
if data_format is "NCW",
`batch_shape + [in_width, in_channels]`
if `data_format` is `"NWC"`, or
`batch_shape + [in_channels, in_width]`
if `data_format` is `"NCW"`,
and a filter / kernel tensor of shape
[filter_width, in_channels, out_channels], this op reshapes
the arguments to pass them to conv2d to perform the equivalent
`[filter_width, in_channels, out_channels]`, this op reshapes
the arguments to pass them to `conv2d` to perform the equivalent
convolution operation.
Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
For example, if `data_format` does not start with "NC", a tensor of shape
[batch, in_width, in_channels]
For example, if `data_format` does not start with `"NC"`, a tensor of shape
`batch_shape + [in_width, in_channels]`
is reshaped to
[batch, 1, in_width, in_channels],
`batch_shape + [1, in_width, in_channels]`,
and the filter is reshaped to
[1, filter_width, in_channels, out_channels].
`[1, filter_width, in_channels, out_channels]`.
The result is then reshaped back to
[batch, out_width, out_channels]
`batch_shape + [out_width, out_channels]`
\(where out_width is a function of the stride and padding as in conv2d\) and
returned to the caller.
Args:
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `input`.
input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
`float64`.
filters: A Tensor of rank at least 3. Must have the same type as `input`.
stride: An int or list of `ints` that has length `1` or `3`. The number of
entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
the data is stored in the order of [batch, in_width, in_channels]. The
`"NCW"` format stores data as [batch, in_channels, in_width].
the data is stored in the order of
`batch_shape + [in_width, in_channels]`. The `"NCW"` format stores data
as `batch_shape + [in_channels, in_width]`.
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
@@ -2071,9 +2089,9 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin
Args:
input: A `Tensor`. Must be one of the following types:
`half`, `bfloat16`, `float32`, `float64`.
A 4+-D tensor. The dimension order is interpreted according to the value
of `data_format`; with the all-but-inner-3 dimensions acting as batch
dimensions. See below for details.
A Tensor of rank at least 4. The dimension order is interpreted according
to the value of `data_format`; with the all-but-inner-3 dimensions acting
as batch dimensions. See below for details.
filters: A `Tensor`. Must have the same type as `input`.
A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
@@ -2214,6 +2232,8 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
ndims = getattr(shape, "ndims", -1)
if ndims == -1: ndims = len(shape)
if ndims in (4, 3, 2, 1, 0, None):
# We avoid calling squeeze_batch_dims to reduce extra python function
# call slowdown in eager mode. This branch doesn't require reshapes.
return gen_nn_ops.conv2d(
input,
filter=filter,
@@ -2535,6 +2555,8 @@ def _conv2d_expanded_batch(
ndims = getattr(shape, "ndims", -1)
if ndims == -1: ndims = len(shape)
if ndims in (4, 3, 2, 1, 0, None):
# We avoid calling squeeze_batch_dims to reduce extra python function
# call slowdown in eager mode. This branch doesn't require reshapes.
return gen_nn_ops.conv2d(
input,
filter=filters,
@@ -2931,6 +2953,46 @@ def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin
name=name)
def _conv3d_expanded_batch(
input, # pylint: disable=redefined-builtin
filter, # pylint: disable=redefined-builtin
strides,
padding,
data_format,
dilations=None,
name=None):
"""Helper function for `conv3d`; handles expanded batches."""
# Try really hard to avoid modifying the legacy name scopes - return early.
shape = getattr(input, "shape", None)
if shape is not None:
ndims = getattr(shape, "ndims", -1)
if ndims == -1:
ndims = len(shape)
if ndims in (5, 4, 3, 2, 1, 0, None):
# We avoid calling squeeze_batch_dims to reduce extra python function
# call slowdown in eager mode. This branch doesn't require reshapes.
return gen_nn_ops.conv3d(
input,
filter,
strides,
padding,
data_format=data_format,
dilations=dilations,
name=name)
else:
return squeeze_batch_dims(
input,
functools.partial(
gen_nn_ops.conv3d,
filter=filter,
strides=strides,
padding=padding,
data_format=data_format,
dilations=dilations),
inner_rank=4,
name=name)
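
With this helper wired in, tf.nn.conv3d accepts inputs of rank greater than 5. An illustrative call (assuming a TF build that includes this change) that takes the squeeze_batch_dims path, since rank 6 falls outside the fast branch above:

import tensorflow as tf

x = tf.random.normal([2, 5, 4, 4, 4, 3])  # batch_shape = [2, 5]
k = tf.random.normal([1, 1, 1, 3, 8])
y = tf.nn.conv3d(x, k, strides=[1, 1, 1, 1, 1], padding="SAME")
print(y.shape)  # (2, 5, 4, 4, 4, 8)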
@tf_export("nn.conv3d", v1=[])
@dispatch.add_dispatch_support
def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
@@ -2942,13 +3004,8 @@ def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
name=None):
if dilations is None:
dilations = [1, 1, 1, 1, 1]
return gen_nn_ops.conv3d(input,
filters,
strides,
padding,
data_format=data_format,
dilations=dilations,
name=name)
return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
dilations, name)
@tf_export(v1=["nn.conv3d"])