Standardise inputs for the conv1d and add the ability to perform dilations
PiperOrigin-RevId: 232507407
This commit is contained in:
parent
f6a39be9f4
commit
e1f2db44f4
@ -3999,14 +3999,16 @@ def fractional_avg_pool_v2(value,
|
||||
"`NHWC` for data_format is deprecated, use `NWC` instead",
|
||||
warn_once=True,
|
||||
data_format="NHWC")
|
||||
def conv1d(value=None,
|
||||
filters=None,
|
||||
stride=None,
|
||||
padding=None,
|
||||
def conv1d(
|
||||
value,
|
||||
filters,
|
||||
stride,
|
||||
padding,
|
||||
use_cudnn_on_gpu=None,
|
||||
data_format=None,
|
||||
name=None,
|
||||
input=None): # pylint: disable=redefined-builtin
|
||||
input=None, # pylint: disable=redefined-builtin
|
||||
dilations=None):
|
||||
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
||||
|
||||
Given an input tensor of shape
|
||||
@ -4034,8 +4036,8 @@ def conv1d(value=None,
|
||||
Args:
|
||||
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
||||
filters: A 3D `Tensor`. Must have the same type as `value`.
|
||||
stride: An `integer`. The number of entries by which
|
||||
the filter is moved right at each step.
|
||||
stride: An int or list of `ints` that has length `1` or `3`. The number of
|
||||
entries by which the filter is moved right at each step.
|
||||
padding: 'SAME' or 'VALID'
|
||||
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
|
||||
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
||||
@ -4044,6 +4046,10 @@ def conv1d(value=None,
|
||||
data as [batch, in_channels, in_width].
|
||||
name: A name for the operation (optional).
|
||||
input: Alias for value.
|
||||
dilations: An int or list of `ints` that has length `1` or `3` which
|
||||
defaults to 1. The dilation factor for each dimension of input. If set to
|
||||
k > 1, there will be k-1 skipped cells between each filter element on that
|
||||
dimension. Dilations in the batch and depth dimensions must be 1.
|
||||
|
||||
Returns:
|
||||
A `Tensor`. Has the same type as input.
|
||||
@ -4057,13 +4063,16 @@ def conv1d(value=None,
|
||||
if data_format is None or data_format == "NHWC" or data_format == "NWC":
|
||||
data_format = "NHWC"
|
||||
spatial_start_dim = 1
|
||||
strides = [1, 1, stride, 1]
|
||||
channel_index = 2
|
||||
elif data_format == "NCHW" or data_format == "NCW":
|
||||
data_format = "NCHW"
|
||||
spatial_start_dim = 2
|
||||
strides = [1, 1, 1, stride]
|
||||
channel_index = 1
|
||||
else:
|
||||
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
||||
strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
|
||||
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
|
||||
|
||||
value = array_ops.expand_dims(value, spatial_start_dim)
|
||||
filters = array_ops.expand_dims(filters, 0)
|
||||
result = gen_nn_ops.conv2d(
|
||||
@ -4072,16 +4081,20 @@ def conv1d(value=None,
|
||||
strides,
|
||||
padding,
|
||||
use_cudnn_on_gpu=use_cudnn_on_gpu,
|
||||
data_format=data_format)
|
||||
data_format=data_format,
|
||||
dilations=dilations,
|
||||
name=name)
|
||||
return array_ops.squeeze(result, [spatial_start_dim])
|
||||
|
||||
|
||||
@tf_export("nn.conv1d", v1=[])
|
||||
def conv1d_v2(input, # pylint: disable=redefined-builtin
|
||||
def conv1d_v2(
|
||||
input, # pylint: disable=redefined-builtin
|
||||
filters,
|
||||
stride,
|
||||
padding,
|
||||
data_format=None,
|
||||
data_format="NWC",
|
||||
dilations=None,
|
||||
name=None):
|
||||
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
||||
|
||||
@ -4110,13 +4123,17 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
|
||||
Args:
|
||||
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
||||
filters: A 3D `Tensor`. Must have the same type as `input`.
|
||||
stride: An `integer`. The number of entries by which
|
||||
the filter is moved right at each step.
|
||||
stride: An int or list of `ints` that has length `1` or `3`. The number of
|
||||
entries by which the filter is moved right at each step.
|
||||
padding: 'SAME' or 'VALID'
|
||||
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
||||
to `"NWC"`, the data is stored in the order of
|
||||
[batch, in_width, in_channels]. The `"NCW"` format stores
|
||||
data as [batch, in_channels, in_width].
|
||||
dilations: An int or list of `ints` that has length `1` or `3` which
|
||||
defaults to 1. The dilation factor for each dimension of input. If set to
|
||||
k > 1, there will be k-1 skipped cells between each filter element on that
|
||||
dimension. Dilations in the batch and depth dimensions must be 1.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -4125,13 +4142,15 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
|
||||
Raises:
|
||||
ValueError: if `data_format` is invalid.
|
||||
"""
|
||||
return conv1d(input, # pylint: disable=redefined-builtin
|
||||
return conv1d(
|
||||
input, # pylint: disable=redefined-builtin
|
||||
filters,
|
||||
stride,
|
||||
padding,
|
||||
use_cudnn_on_gpu=True,
|
||||
data_format=data_format,
|
||||
name=name)
|
||||
name=name,
|
||||
dilations=dilations)
|
||||
|
||||
|
||||
def conv1d_transpose(
|
||||
|
@ -54,7 +54,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "conv1d"
|
||||
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "conv2d"
|
||||
|
@ -50,7 +50,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "conv1d"
|
||||
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "conv2d"
|
||||
|
Loading…
Reference in New Issue
Block a user