Standardise inputs for the conv1d and add the ability to perform dilations

PiperOrigin-RevId: 232507407
This commit is contained in:
Tamara Norman 2019-02-05 10:04:57 -08:00 committed by TensorFlower Gardener
parent f6a39be9f4
commit e1f2db44f4
3 changed files with 49 additions and 30 deletions

View File

@ -3999,14 +3999,16 @@ def fractional_avg_pool_v2(value,
"`NHWC` for data_format is deprecated, use `NWC` instead",
warn_once=True,
data_format="NHWC")
def conv1d(value=None,
filters=None,
stride=None,
padding=None,
def conv1d(
value,
filters,
stride,
padding,
use_cudnn_on_gpu=None,
data_format=None,
name=None,
input=None): # pylint: disable=redefined-builtin
input=None, # pylint: disable=redefined-builtin
dilations=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape
@ -4034,8 +4036,8 @@ def conv1d(value=None,
Args:
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `value`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
stride: An int or list of `ints` that has length `1` or `3`. The number of
entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
@ -4044,6 +4046,10 @@ def conv1d(value=None,
data as [batch, in_channels, in_width].
name: A name for the operation (optional).
input: Alias for value.
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
dimension. Dilations in the batch and depth dimensions must be 1.
Returns:
A `Tensor`. Has the same type as input.
@ -4057,13 +4063,16 @@ def conv1d(value=None,
if data_format is None or data_format == "NHWC" or data_format == "NWC":
data_format = "NHWC"
spatial_start_dim = 1
strides = [1, 1, stride, 1]
channel_index = 2
elif data_format == "NCHW" or data_format == "NCW":
data_format = "NCHW"
spatial_start_dim = 2
strides = [1, 1, 1, stride]
channel_index = 1
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
value = array_ops.expand_dims(value, spatial_start_dim)
filters = array_ops.expand_dims(filters, 0)
result = gen_nn_ops.conv2d(
@ -4072,16 +4081,20 @@ def conv1d(value=None,
strides,
padding,
use_cudnn_on_gpu=use_cudnn_on_gpu,
data_format=data_format)
data_format=data_format,
dilations=dilations,
name=name)
return array_ops.squeeze(result, [spatial_start_dim])
@tf_export("nn.conv1d", v1=[])
def conv1d_v2(input, # pylint: disable=redefined-builtin
def conv1d_v2(
input, # pylint: disable=redefined-builtin
filters,
stride,
padding,
data_format=None,
data_format="NWC",
dilations=None,
name=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors.
@ -4110,13 +4123,17 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
Args:
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `input`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
stride: An int or list of `ints` that has length `1` or `3`. The number of
entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
to `"NWC"`, the data is stored in the order of
[batch, in_width, in_channels]. The `"NCW"` format stores
data as [batch, in_channels, in_width].
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
dimension. Dilations in the batch and depth dimensions must be 1.
name: A name for the operation (optional).
Returns:
@ -4125,13 +4142,15 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
Raises:
ValueError: if `data_format` is invalid.
"""
return conv1d(input, # pylint: disable=redefined-builtin
return conv1d(
input, # pylint: disable=redefined-builtin
filters,
stride,
padding,
use_cudnn_on_gpu=True,
data_format=data_format,
name=name)
name=name,
dilations=dilations)
def conv1d_transpose(

View File

@ -54,7 +54,7 @@ tf_module {
}
member_method {
name: "conv1d"
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"

View File

@ -50,7 +50,7 @@ tf_module {
}
member_method {
name: "conv1d"
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"