Standardise inputs for the conv1d and add the ability to perform dilations
PiperOrigin-RevId: 232507407
This commit is contained in:
parent
f6a39be9f4
commit
e1f2db44f4
@ -3999,14 +3999,16 @@ def fractional_avg_pool_v2(value,
|
|||||||
"`NHWC` for data_format is deprecated, use `NWC` instead",
|
"`NHWC` for data_format is deprecated, use `NWC` instead",
|
||||||
warn_once=True,
|
warn_once=True,
|
||||||
data_format="NHWC")
|
data_format="NHWC")
|
||||||
def conv1d(value=None,
|
def conv1d(
|
||||||
filters=None,
|
value,
|
||||||
stride=None,
|
filters,
|
||||||
padding=None,
|
stride,
|
||||||
use_cudnn_on_gpu=None,
|
padding,
|
||||||
data_format=None,
|
use_cudnn_on_gpu=None,
|
||||||
name=None,
|
data_format=None,
|
||||||
input=None): # pylint: disable=redefined-builtin
|
name=None,
|
||||||
|
input=None, # pylint: disable=redefined-builtin
|
||||||
|
dilations=None):
|
||||||
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
||||||
|
|
||||||
Given an input tensor of shape
|
Given an input tensor of shape
|
||||||
@ -4034,8 +4036,8 @@ def conv1d(value=None,
|
|||||||
Args:
|
Args:
|
||||||
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
||||||
filters: A 3D `Tensor`. Must have the same type as `value`.
|
filters: A 3D `Tensor`. Must have the same type as `value`.
|
||||||
stride: An `integer`. The number of entries by which
|
stride: An int or list of `ints` that has length `1` or `3`. The number of
|
||||||
the filter is moved right at each step.
|
entries by which the filter is moved right at each step.
|
||||||
padding: 'SAME' or 'VALID'
|
padding: 'SAME' or 'VALID'
|
||||||
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
|
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
|
||||||
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
||||||
@ -4044,6 +4046,10 @@ def conv1d(value=None,
|
|||||||
data as [batch, in_channels, in_width].
|
data as [batch, in_channels, in_width].
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
input: Alias for value.
|
input: Alias for value.
|
||||||
|
dilations: An int or list of `ints` that has length `1` or `3` which
|
||||||
|
defaults to 1. The dilation factor for each dimension of input. If set to
|
||||||
|
k > 1, there will be k-1 skipped cells between each filter element on that
|
||||||
|
dimension. Dilations in the batch and depth dimensions must be 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Tensor`. Has the same type as input.
|
A `Tensor`. Has the same type as input.
|
||||||
@ -4057,13 +4063,16 @@ def conv1d(value=None,
|
|||||||
if data_format is None or data_format == "NHWC" or data_format == "NWC":
|
if data_format is None or data_format == "NHWC" or data_format == "NWC":
|
||||||
data_format = "NHWC"
|
data_format = "NHWC"
|
||||||
spatial_start_dim = 1
|
spatial_start_dim = 1
|
||||||
strides = [1, 1, stride, 1]
|
channel_index = 2
|
||||||
elif data_format == "NCHW" or data_format == "NCW":
|
elif data_format == "NCHW" or data_format == "NCW":
|
||||||
data_format = "NCHW"
|
data_format = "NCHW"
|
||||||
spatial_start_dim = 2
|
spatial_start_dim = 2
|
||||||
strides = [1, 1, 1, stride]
|
channel_index = 1
|
||||||
else:
|
else:
|
||||||
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
||||||
|
strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
|
||||||
|
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
|
||||||
|
|
||||||
value = array_ops.expand_dims(value, spatial_start_dim)
|
value = array_ops.expand_dims(value, spatial_start_dim)
|
||||||
filters = array_ops.expand_dims(filters, 0)
|
filters = array_ops.expand_dims(filters, 0)
|
||||||
result = gen_nn_ops.conv2d(
|
result = gen_nn_ops.conv2d(
|
||||||
@ -4072,17 +4081,21 @@ def conv1d(value=None,
|
|||||||
strides,
|
strides,
|
||||||
padding,
|
padding,
|
||||||
use_cudnn_on_gpu=use_cudnn_on_gpu,
|
use_cudnn_on_gpu=use_cudnn_on_gpu,
|
||||||
data_format=data_format)
|
data_format=data_format,
|
||||||
|
dilations=dilations,
|
||||||
|
name=name)
|
||||||
return array_ops.squeeze(result, [spatial_start_dim])
|
return array_ops.squeeze(result, [spatial_start_dim])
|
||||||
|
|
||||||
|
|
||||||
@tf_export("nn.conv1d", v1=[])
|
@tf_export("nn.conv1d", v1=[])
|
||||||
def conv1d_v2(input, # pylint: disable=redefined-builtin
|
def conv1d_v2(
|
||||||
filters,
|
input, # pylint: disable=redefined-builtin
|
||||||
stride,
|
filters,
|
||||||
padding,
|
stride,
|
||||||
data_format=None,
|
padding,
|
||||||
name=None):
|
data_format="NWC",
|
||||||
|
dilations=None,
|
||||||
|
name=None):
|
||||||
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
r"""Computes a 1-D convolution given 3-D input and filter tensors.
|
||||||
|
|
||||||
Given an input tensor of shape
|
Given an input tensor of shape
|
||||||
@ -4110,13 +4123,17 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
|
|||||||
Args:
|
Args:
|
||||||
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
|
||||||
filters: A 3D `Tensor`. Must have the same type as `input`.
|
filters: A 3D `Tensor`. Must have the same type as `input`.
|
||||||
stride: An `integer`. The number of entries by which
|
stride: An int or list of `ints` that has length `1` or `3`. The number of
|
||||||
the filter is moved right at each step.
|
entries by which the filter is moved right at each step.
|
||||||
padding: 'SAME' or 'VALID'
|
padding: 'SAME' or 'VALID'
|
||||||
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
|
||||||
to `"NWC"`, the data is stored in the order of
|
to `"NWC"`, the data is stored in the order of
|
||||||
[batch, in_width, in_channels]. The `"NCW"` format stores
|
[batch, in_width, in_channels]. The `"NCW"` format stores
|
||||||
data as [batch, in_channels, in_width].
|
data as [batch, in_channels, in_width].
|
||||||
|
dilations: An int or list of `ints` that has length `1` or `3` which
|
||||||
|
defaults to 1. The dilation factor for each dimension of input. If set to
|
||||||
|
k > 1, there will be k-1 skipped cells between each filter element on that
|
||||||
|
dimension. Dilations in the batch and depth dimensions must be 1.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -4125,13 +4142,15 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if `data_format` is invalid.
|
ValueError: if `data_format` is invalid.
|
||||||
"""
|
"""
|
||||||
return conv1d(input, # pylint: disable=redefined-builtin
|
return conv1d(
|
||||||
filters,
|
input, # pylint: disable=redefined-builtin
|
||||||
stride,
|
filters,
|
||||||
padding,
|
stride,
|
||||||
use_cudnn_on_gpu=True,
|
padding,
|
||||||
data_format=data_format,
|
use_cudnn_on_gpu=True,
|
||||||
name=name)
|
data_format=data_format,
|
||||||
|
name=name,
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
|
|
||||||
def conv1d_transpose(
|
def conv1d_transpose(
|
||||||
|
@ -54,7 +54,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv1d"
|
name: "conv1d"
|
||||||
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv2d"
|
name: "conv2d"
|
||||||
|
@ -50,7 +50,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv1d"
|
name: "conv1d"
|
||||||
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv2d"
|
name: "conv2d"
|
||||||
|
Loading…
Reference in New Issue
Block a user