[TF] Add extra batch dim support to tf.nn.conv{1,3}d and keras.layers.Conv{1,3}D.
This extends the support just added for tf.nn.conv2d and keras.layers.Conv2D.

PiperOrigin-RevId: 314738423
Change-Id: I588c36e493d1f41a67b9721e2cbd84c564277f44
Parent: 47f7695c00
Commit: 549e69ca13
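Before the diffs, a quick sketch of the user-visible behavior this commit enables. The shapes below are illustrative only, not taken from the change itself:

    import tensorflow as tf

    # conv1d now accepts an input whose leading dims form an arbitrary batch
    # shape, here (2, 5), with the usual rank-3 filter.
    x = tf.random.normal([2, 5, 10, 3])   # batch_shape + [in_width, in_channels]
    f = tf.random.normal([4, 3, 8])       # [filter_width, in_channels, out_channels]
    y = tf.nn.conv1d(x, f, stride=1, padding="VALID")
    print(y.shape)                        # (2, 5, 7, 8)

    # Likewise conv3d with batch shape (2, 5): rank-6 input, rank-5 filter.
    x3 = tf.random.normal([2, 5, 8, 8, 8, 1])
    f3 = tf.random.normal([3, 3, 3, 1, 4])
    y3 = tf.nn.conv3d(x3, f3, strides=[1, 1, 1, 1, 1], padding="VALID")
    print(y3.shape)                       # (2, 5, 6, 6, 6, 4)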
@@ -228,7 +228,7 @@ class Conv(Layer):
     # Apply causal padding to inputs for Conv1D.
     if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
-      inputs = array_ops.pad(inputs, self._compute_causal_padding())
+      inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
 
     outputs = self._convolution_op(inputs, self.kernel)
@@ -324,13 +324,17 @@ class Conv(Layer):
     base_config = super(Conv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def _compute_causal_padding(self):
+  def _compute_causal_padding(self, inputs):
     """Calculates padding for 'causal' option for 1-d conv layers."""
     left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
-    if self.data_format == 'channels_last':
-      causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+    if getattr(inputs.shape, 'ndims', None) is None:
+      batch_rank = 1
     else:
-      causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+      batch_rank = len(inputs.shape) - 2
+    if self.data_format == 'channels_last':
+      causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
+    else:
+      causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
     return causal_padding
 
   def _get_channel_axis(self):
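The new batch_rank logic pads only the steps axis, no matter how many leading batch dimensions precede it. A standalone check of the arithmetic, with illustrative values:

    import tensorflow as tf

    kernel_size, dilation_rate = 3, 2
    left_pad = dilation_rate * (kernel_size - 1)   # 4 steps of left padding
    x = tf.random.normal([4, 7, 10, 128])          # batch_shape=(4, 7), steps=10
    batch_rank = len(x.shape) - 2                  # all dims except steps and channels
    causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
    print(tf.pad(x, causal_padding).shape)         # (4, 7, 14, 128)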
@@ -404,6 +408,16 @@ class Conv1D(Conv):
   >>> print(y.shape)
   (4, 8, 32)
 
+  >>> # With extended batch shape [4, 7] (e.g. weather data where batch
+  >>> # dimensions correspond to spatial location and the third dimension
+  >>> # corresponds to time.)
+  >>> input_shape = (4, 7, 10, 128)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv1D(
+  ...     32, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 8, 32)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -451,10 +465,10 @@ class Conv1D(Conv):
       see `keras.constraints`).
 
   Input shape:
-    3D tensor with shape: `(batch_size, steps, input_dim)`
+    3+D tensor with shape: `batch_shape + (steps, input_dim)`
 
   Output shape:
-    3D tensor with shape: `(batch_size, new_steps, filters)`
+    3+D tensor with shape: `batch_shape + (new_steps, filters)`
       `steps` value might have changed due to padding or strides.
 
   Returns:
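For reference, `new_steps` follows the usual convolution arithmetic regardless of how many batch dimensions precede it; with `padding='valid'`, new_steps = (steps - kernel_size) // strides + 1. A quick check with illustrative values:

    import tensorflow as tf

    layer = tf.keras.layers.Conv1D(32, kernel_size=3, strides=2, padding='valid')
    y = layer(tf.random.normal([4, 7, 10, 128]))  # batch_shape + (steps, input_dim)
    print(y.shape)                                # (4, 7, 4, 32): (10 - 3) // 2 + 1 = 4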
@@ -462,7 +476,7 @@ class Conv1D(Conv):
       `activation(conv1d(inputs, kernel) + bias)`.
 
   Raises:
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """
 
   def __init__(self,
@@ -702,6 +716,15 @@ class Conv3D(Conv):
   >>> print(y.shape)
   (4, 26, 26, 26, 2)
 
+  >>> # With extended batch shape [4, 7], e.g. a batch of 4 videos of 3D frames,
+  >>> # with 7 frames per video.
+  >>> input_shape = (4, 7, 28, 28, 28, 1)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv3D(
+  ...     2, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 26, 26, 26, 2)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -721,9 +744,9 @@ class Conv3D(Conv):
       one of `channels_last` (default) or `channels_first`.
       The ordering of the dimensions in the inputs.
       `channels_last` corresponds to inputs with shape
-      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
+      `batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
       while `channels_first` corresponds to inputs with shape
-      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
+      `batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
       It defaults to the `image_data_format` value found in your
       Keras config file at `~/.keras/keras.json`.
       If you never set it, then it will be "channels_last".
@@ -760,30 +783,30 @@ class Conv3D(Conv):
       see `keras.constraints`).
 
   Input shape:
-    5D tensor with shape:
-    `(batch_size, channels, conv_dim1, conv_dim2, conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
       data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, conv_dim1, conv_dim2, conv_dim3, channels)` if
+    or 5+D tensor with shape:
+    `batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
       data_format='channels_last'.
 
   Output shape:
-    5D tensor with shape:
-    `(batch_size, filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
       data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
+    or 5+D tensor with shape:
+    `batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
       data_format='channels_last'.
     `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
       changed due to padding.
 
   Returns:
-    A tensor of rank 5 representing
+    A tensor of rank 5+ representing
       `activation(conv3d(inputs, kernel) + bias)`.
 
   Raises:
     ValueError: if `padding` is "causal".
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """
 
   def __init__(self,
@@ -49,6 +49,21 @@ class Conv1DTest(keras_parameterized.TestCase):
         input_shape=(num_samples, length, stack_size),
         expected_output_shape=expected_output_shape)
 
+  def _run_test_extra_batch_dim(self, kwargs, expected_output_shape):
+    batch_shape = (2, 11)
+    stack_size = 3
+    length = 7
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+
+      testing_utils.layer_test(
+          keras.layers.Conv1D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (length, stack_size),
+          expected_output_shape=expected_output_shape)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -85,6 +100,7 @@ class Conv1DTest(keras_parameterized.TestCase):
     kwargs['kernel_size'] = 3
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape)
 
   def test_conv1d_regularizers(self):
     kwargs = {
@@ -281,6 +297,27 @@ class Conv3DTest(keras_parameterized.TestCase):
         expected_output_shape=expected_output_shape,
         validate_training=validate_training)
 
+  def _run_test_extra_batch_dim(self,
+                                kwargs,
+                                expected_output_shape,
+                                validate_training=True):
+    batch_shape = (2, 11)
+    stack_size = 3
+    num_row = 7
+    num_col = 6
+    depth = 5
+
+    with self.cached_session(use_gpu=True):
+      if expected_output_shape is not None:
+        expected_output_shape = (None,) + expected_output_shape
+
+      testing_utils.layer_test(
+          keras.layers.Conv3D,
+          kwargs=kwargs,
+          input_shape=batch_shape + (depth, num_row, num_col, stack_size),
+          expected_output_shape=expected_output_shape,
+          validate_training=validate_training)
+
   @parameterized.named_parameters(
       ('padding_valid', {
           'padding': 'valid'
@@ -313,6 +350,8 @@ class Conv3DTest(keras_parameterized.TestCase):
     test_training = 'groups' not in kwargs or not test_util.is_xla_enabled()
     if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape, test_training)
+      self._run_test_extra_batch_dim(kwargs, expected_output_shape,
+                                     test_training)
 
   def test_conv3d_regularizers(self):
     kwargs = {
@@ -55,6 +55,32 @@ class Conv1DTest(test.TestCase):
         self.assertEqual(len(output), 2)
         self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
 
+  def testExpandedBatch(self):
+    """Test that conv1d handles an expanded (rank > 1) batch correctly."""
+    # double datatype is currently not supported for convolution ops
+    # on the ROCm platform
+    x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32)
+    x = array_ops.expand_dims(x, 0)  # Add batch dimension
+    x = array_ops.expand_dims(x, 2)  # And depth dimension
+    x = array_ops.stack([x, x])  # Make batch shape [2, 1]
+    filters = constant_op.constant([2, 1], dtype=dtypes.float32)
+    filters = array_ops.expand_dims(filters, 1)  # in_channels
+    filters = array_ops.expand_dims(filters, 2)  # out_channels
+    # Filters is 2x1x1
+    for stride in [1, 2]:
+      with self.cached_session(use_gpu=test.is_gpu_available()):
+        c = nn_ops.conv1d(x, filters, stride, padding="VALID")
+        reduced = array_ops.squeeze(c)  # Squeeze out dims 1 and 3.
+        output = self.evaluate(reduced)
+        if stride == 1:
+          self.assertAllClose(output,
+                              [[2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4],
+                               [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]])
+        else:
+          self.assertAllClose(
+              output,
+              [[2 * 1 + 1 * 2, 2 * 3 + 1 * 4], [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]])
+
   def testConv1DTranspose(self):
     with self.cached_session():
       stride = 2
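The same check can be written against the public API. A sketch of what this test exercises, using tf.nn symbols instead of the internal nn_ops/array_ops modules:

    import tensorflow as tf

    x = tf.reshape(tf.constant([1., 2., 3., 4.]), [1, 4, 1])  # [batch, width, channels]
    x = tf.stack([x, x])                                      # batch shape [2, 1]
    filters = tf.reshape(tf.constant([2., 1.]), [2, 1, 1])
    y = tf.nn.conv1d(x, filters, stride=1, padding="VALID")
    print(tf.squeeze(y).numpy())  # [[4., 7., 10.], [4., 7., 10.]]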
@@ -66,12 +66,8 @@ class Conv3DTest(test.TestCase):
 
   def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
                             padding, data_format, dtype, use_gpu):
-    total_size_tensor = 1
-    total_size_filter = 1
-    for s in tensor_in_sizes:
-      total_size_tensor *= s
-    for s in filter_in_sizes:
-      total_size_filter *= s
+    total_size_tensor = np.prod(tensor_in_sizes)
+    total_size_filter = np.prod(filter_in_sizes)
 
     # Initializes the input tensor with array containing numbers from 0 to 1.
     # We keep the input tensor values fairly small to avoid overflowing float16
@@ -126,12 +122,8 @@ class Conv3DTest(test.TestCase):
   def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes,
                                    stride, dilation, padding, data_format,
                                    use_gpu):
-    total_size_tensor = 1
-    total_size_filter = 1
-    for s in tensor_in_sizes:
-      total_size_tensor *= s
-    for s in filter_in_sizes:
-      total_size_filter *= s
+    total_size_tensor = np.prod(tensor_in_sizes)
+    total_size_filter = np.prod(filter_in_sizes)
 
     # Initializes the input tensor with array containing incrementing
     # numbers from 1.
@@ -196,6 +188,69 @@ class Conv3DTest(test.TestCase):
       self.assertAllClose(
           e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)
 
+  def _CreateNumpyTensor(self, sizes):
+    return np.asarray([f * 1.0
+                       for f in range(1,
+                                      np.prod(sizes) + 1)]).reshape(sizes)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConv3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.conv3d_v2(
+        x1, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    conv2 = nn_ops.conv3d_v2(
+        x2, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1, filter_in, strides=[1, 1, 1], padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2, filter_in, strides=[1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
   def testConv3D1x1x1Filter(self):
     expected_output = [
         0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
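All three tests assert the same invariant: convolving over batch shape [2, 5] must equal convolving over the flattened batch [10] and reshaping back. The same property restated with public tf.nn symbols (a sketch, not the test code itself):

    import numpy as np
    import tensorflow as tf

    x1 = np.arange(1.0, 181.0).reshape([10, 2, 3, 1, 3])  # 10*2*3*1*3 = 180 values
    x2 = x1.reshape([2, 5, 2, 3, 1, 3])
    filt = np.arange(1.0, 10.0).reshape([1, 1, 1, 3, 3])
    c1 = tf.nn.conv3d(x1, filt, strides=[1] * 5, padding="VALID")
    c2 = tf.nn.conv3d(x2, filt, strides=[1] * 5, padding="VALID")
    np.testing.assert_allclose(c1.numpy(), c2.numpy().reshape(c1.shape))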
@@ -228,7 +228,7 @@ class _NonAtrousConvolution(object):
             % data_format)
       self.strides = strides
       self.data_format = data_format
-      self.conv_op = gen_nn_ops.conv3d
+      self.conv_op = _conv3d_expanded_batch
 
   # Note that we need this adapter since argument names for conv1d don't match
   # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
@@ -241,7 +241,6 @@ class _NonAtrousConvolution(object):
         padding=padding,
         data_format=data_format,
         name=name)
 
-  # pylint: enable=redefined-builtin
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
@@ -1104,7 +1103,7 @@ def convolution_internal(
     scope = "convolution"
 
   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}
 
     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1783,40 +1782,42 @@ def conv1d(
     name=None,
     input=None,  # pylint: disable=redefined-builtin
     dilations=None):
-  r"""Computes a 1-D convolution given 3-D input and filter tensors.
+  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.
 
   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.
 
   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
   For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.
 
   Args:
-    value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`. Must have the same type as `value`.
+    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
+      `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `value`.
     stride: An int or list of `ints` that has length `1` or `3`. The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
     data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels]. The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of `batch_shape + [in_width,
+      in_channels]`. The `"NCW"` format stores data as `batch_shape +
+      [in_channels, in_width]`.
     name: A name for the operation (optional).
     input: Alias for value.
     dilations: An int or list of `ints` that has length `1` or `3` which
@@ -1832,14 +1833,14 @@ def conv1d(
   """
   value = deprecation.deprecated_argument_lookup("input", input, "value", value)
   with ops.name_scope(name, "conv1d", [value, filters]) as name:
-    # Reshape the input tensor to [batch, 1, in_width, in_channels]
+    # Reshape the input tensor to batch_shape + [1, in_width, in_channels]
    if data_format is None or data_format == "NHWC" or data_format == "NWC":
       data_format = "NHWC"
-      spatial_start_dim = 1
+      spatial_start_dim = -3
       channel_index = 2
     elif data_format == "NCHW" or data_format == "NCW":
       data_format = "NCHW"
-      spatial_start_dim = 2
+      spatial_start_dim = -2
       channel_index = 1
     else:
       raise ValueError("data_format must be \"NWC\" or \"NCW\".")
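Switching spatial_start_dim from a positive to a negative index is what makes the subsequent expand_dims batch-agnostic: counting from the end leaves any extra leading batch dimensions untouched. Illustration with a hypothetical shape:

    import tensorflow as tf

    x = tf.random.normal([2, 5, 10, 3])  # batch_shape=(2, 5) + [in_width, in_channels]
    h = tf.expand_dims(x, -3)            # insert the dummy height axis from the end
    print(h.shape)                       # (2, 5, 1, 10, 3)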
@@ -1848,15 +1849,30 @@ def conv1d(
 
     value = array_ops.expand_dims(value, spatial_start_dim)
     filters = array_ops.expand_dims(filters, 0)
-    result = gen_nn_ops.conv2d(
-        value,
-        filters,
-        strides,
-        padding,
-        use_cudnn_on_gpu=use_cudnn_on_gpu,
-        data_format=data_format,
-        dilations=dilations,
-        name=name)
+    if value.shape.ndims in (4, 3, 2, 1, 0, None):
+      result = gen_nn_ops.conv2d(
+          value,
+          filters,
+          strides,
+          padding,
+          use_cudnn_on_gpu=use_cudnn_on_gpu,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+    else:
+      result = squeeze_batch_dims(
+          value,
+          functools.partial(
+              gen_nn_ops.conv2d,
+              filter=filters,
+              strides=strides,
+              padding=padding,
+              use_cudnn_on_gpu=use_cudnn_on_gpu,
+              data_format=data_format,
+              dilations=dilations,
+          ),
+          inner_rank=3,
+          name=name)
     return array_ops.squeeze(result, [spatial_start_dim])
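`squeeze_batch_dims` is defined elsewhere in nn_ops and is not shown in this diff; my reading of its contract, sketched here as a hypothetical re-implementation, is: collapse all leading batch dimensions into one, run the inner op, then restore the original batch shape around the op's output:

    import tensorflow as tf

    def squeeze_batch_dims_sketch(inp, op, inner_rank):
      """Illustrative stand-in for nn_ops.squeeze_batch_dims; not the real code."""
      shape = tf.shape(inp)
      batch_shape, inner_shape = shape[:-inner_rank], shape[-inner_rank:]
      # Collapse leading batch dims into one, giving rank inner_rank + 1.
      flat = tf.reshape(inp, tf.concat([[-1], inner_shape], axis=0))
      out = op(flat)
      # Restore the original batch dims around the op's own output dims.
      return tf.reshape(out, tf.concat([batch_shape, tf.shape(out)[1:]], axis=0))

For conv1d the inner rank is 3 because, after the expand_dims above, the op sees batch_shape + [1, in_width, in_channels], and conv2d consumes one batch dim plus those three inner dims.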
@@ -1873,36 +1889,38 @@ def conv1d_v2(
   r"""Computes a 1-D convolution given 3-D input and filter tensors.
 
   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.
 
   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
-  For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+  For example, if `data_format` does not start with `"NC"`, a tensor of shape
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.
 
   Args:
-    input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`. Must have the same type as `input`.
+    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
+      `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `input`.
     stride: An int or list of `ints` that has length `1` or `3`. The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     data_format: An optional `string` from `"NWC", "NCW"`. Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels]. The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of
+      `batch_shape + [in_width, in_channels]`. The `"NCW"` format stores data
+      as `batch_shape + [in_channels, in_width]`.
     dilations: An int or list of `ints` that has length `1` or `3` which
       defaults to 1. The dilation factor for each dimension of input. If set to
       k > 1, there will be k-1 skipped cells between each filter element on that
@@ -2071,9 +2089,9 @@ def conv2d_v2(input,  # pylint: disable=redefined-builtin
   Args:
     input: A `Tensor`. Must be one of the following types:
       `half`, `bfloat16`, `float32`, `float64`.
-      A 4+-D tensor. The dimension order is interpreted according to the value
-      of `data_format`; with the all-but-inner-3 dimensions acting as batch
-      dimensions. See below for details.
+      A Tensor of rank at least 4. The dimension order is interpreted according
+      to the value of `data_format`; with the all-but-inner-3 dimensions acting
+      as batch dimensions. See below for details.
     filters: A `Tensor`. Must have the same type as `input`.
       A 4-D tensor of shape
       `[filter_height, filter_width, in_channels, out_channels]`
@@ -2214,6 +2232,8 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode.  This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filter,
@@ -2535,6 +2555,8 @@ def _conv2d_expanded_batch(
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode.  This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filters,
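The repeated getattr dance keeps the fast path cheap: a TensorShape exposes ndims (possibly None for unknown rank), while a plain list or tuple only has a length, and rank None also takes the fast path because the kernel itself tolerates unknown rank. Restated as a tiny predicate, for illustration only:

    def _takes_fast_path(shape, max_rank=4):
      """Mirrors the branch condition: known rank <= max_rank, or unknown (None)."""
      ndims = getattr(shape, "ndims", -1)
      if ndims == -1:
        ndims = len(shape)
      return ndims in tuple(range(max_rank + 1)) + (None,)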
@@ -2931,6 +2953,46 @@ def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builtin
       name=name)
 
 
+def _conv3d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filter,  # pylint: disable=redefined-builtin
+    strides,
+    padding,
+    data_format,
+    dilations=None,
+    name=None):
+  """Helper function for `conv3d`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1:
+      ndims = len(shape)
+    if ndims in (5, 4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode.  This branch doesn't require reshapes.
+      return gen_nn_ops.conv3d(
+          input,
+          filter,
+          strides,
+          padding,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+    else:
+      return squeeze_batch_dims(
+          input,
+          functools.partial(
+              gen_nn_ops.conv3d,
+              filter=filter,
+              strides=strides,
+              padding=padding,
+              data_format=data_format,
+              dilations=dilations),
+          inner_rank=4,
+          name=name)
+
+
 @tf_export("nn.conv3d", v1=[])
 @dispatch.add_dispatch_support
 def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
@@ -2942,13 +3004,8 @@ def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
               name=None):
   if dilations is None:
     dilations = [1, 1, 1, 1, 1]
-  return gen_nn_ops.conv3d(input,
-                           filters,
-                           strides,
-                           padding,
-                           data_format=data_format,
-                           dilations=dilations,
-                           name=name)
+  return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
+                                dilations, name)
 
 
 @tf_export(v1=["nn.conv3d"])