Standardise inputs for the conv1d and add the ability to perform dilations

PiperOrigin-RevId: 232507407
This commit is contained in:
Tamara Norman 2019-02-05 10:04:57 -08:00 committed by TensorFlower Gardener
parent f6a39be9f4
commit e1f2db44f4
3 changed files with 49 additions and 30 deletions

View File

@ -3999,14 +3999,16 @@ def fractional_avg_pool_v2(value,
"`NHWC` for data_format is deprecated, use `NWC` instead", "`NHWC` for data_format is deprecated, use `NWC` instead",
warn_once=True, warn_once=True,
data_format="NHWC") data_format="NHWC")
def conv1d(value=None, def conv1d(
filters=None, value,
stride=None, filters,
padding=None, stride,
use_cudnn_on_gpu=None, padding,
data_format=None, use_cudnn_on_gpu=None,
name=None, data_format=None,
input=None): # pylint: disable=redefined-builtin name=None,
input=None, # pylint: disable=redefined-builtin
dilations=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors. r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape Given an input tensor of shape
@ -4034,8 +4036,8 @@ def conv1d(value=None,
Args: Args:
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`. value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `value`. filters: A 3D `Tensor`. Must have the same type as `value`.
stride: An `integer`. The number of entries by which stride: An int or list of `ints` that has length `1` or `3`. The number of
the filter is moved right at each step. entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID' padding: 'SAME' or 'VALID'
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`. use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
data_format: An optional `string` from `"NWC", "NCW"`. Defaults data_format: An optional `string` from `"NWC", "NCW"`. Defaults
@ -4044,6 +4046,10 @@ def conv1d(value=None,
data as [batch, in_channels, in_width]. data as [batch, in_channels, in_width].
name: A name for the operation (optional). name: A name for the operation (optional).
input: Alias for value. input: Alias for value.
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
dimension. Dilations in the batch and depth dimensions must be 1.
Returns: Returns:
A `Tensor`. Has the same type as input. A `Tensor`. Has the same type as input.
@ -4057,13 +4063,16 @@ def conv1d(value=None,
if data_format is None or data_format == "NHWC" or data_format == "NWC": if data_format is None or data_format == "NHWC" or data_format == "NWC":
data_format = "NHWC" data_format = "NHWC"
spatial_start_dim = 1 spatial_start_dim = 1
strides = [1, 1, stride, 1] channel_index = 2
elif data_format == "NCHW" or data_format == "NCW": elif data_format == "NCHW" or data_format == "NCW":
data_format = "NCHW" data_format = "NCHW"
spatial_start_dim = 2 spatial_start_dim = 2
strides = [1, 1, 1, stride] channel_index = 1
else: else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".") raise ValueError("data_format must be \"NWC\" or \"NCW\".")
strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
value = array_ops.expand_dims(value, spatial_start_dim) value = array_ops.expand_dims(value, spatial_start_dim)
filters = array_ops.expand_dims(filters, 0) filters = array_ops.expand_dims(filters, 0)
result = gen_nn_ops.conv2d( result = gen_nn_ops.conv2d(
@ -4072,17 +4081,21 @@ def conv1d(value=None,
strides, strides,
padding, padding,
use_cudnn_on_gpu=use_cudnn_on_gpu, use_cudnn_on_gpu=use_cudnn_on_gpu,
data_format=data_format) data_format=data_format,
dilations=dilations,
name=name)
return array_ops.squeeze(result, [spatial_start_dim]) return array_ops.squeeze(result, [spatial_start_dim])
@tf_export("nn.conv1d", v1=[]) @tf_export("nn.conv1d", v1=[])
def conv1d_v2(input, # pylint: disable=redefined-builtin def conv1d_v2(
filters, input, # pylint: disable=redefined-builtin
stride, filters,
padding, stride,
data_format=None, padding,
name=None): data_format="NWC",
dilations=None,
name=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors. r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape Given an input tensor of shape
@ -4110,13 +4123,17 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
Args: Args:
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`. input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `input`. filters: A 3D `Tensor`. Must have the same type as `input`.
stride: An `integer`. The number of entries by which stride: An int or list of `ints` that has length `1` or `3`. The number of
the filter is moved right at each step. entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID' padding: 'SAME' or 'VALID'
data_format: An optional `string` from `"NWC", "NCW"`. Defaults data_format: An optional `string` from `"NWC", "NCW"`. Defaults
to `"NWC"`, the data is stored in the order of to `"NWC"`, the data is stored in the order of
[batch, in_width, in_channels]. The `"NCW"` format stores [batch, in_width, in_channels]. The `"NCW"` format stores
data as [batch, in_channels, in_width]. data as [batch, in_channels, in_width].
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
dimension. Dilations in the batch and depth dimensions must be 1.
name: A name for the operation (optional). name: A name for the operation (optional).
Returns: Returns:
@ -4125,13 +4142,15 @@ def conv1d_v2(input, # pylint: disable=redefined-builtin
Raises: Raises:
ValueError: if `data_format` is invalid. ValueError: if `data_format` is invalid.
""" """
return conv1d(input, # pylint: disable=redefined-builtin return conv1d(
filters, input, # pylint: disable=redefined-builtin
stride, filters,
padding, stride,
use_cudnn_on_gpu=True, padding,
data_format=data_format, use_cudnn_on_gpu=True,
name=name) data_format=data_format,
name=name,
dilations=dilations)
def conv1d_transpose( def conv1d_transpose(

View File

@ -54,7 +54,7 @@ tf_module {
} }
member_method { member_method {
name: "conv1d" name: "conv1d"
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "conv2d" name: "conv2d"

View File

@ -50,7 +50,7 @@ tf_module {
} }
member_method { member_method {
name: "conv1d" name: "conv1d"
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "conv2d" name: "conv2d"