Method for conv_1d transpose updated to use dilations and exported

PiperOrigin-RevId: 233054694
This commit is contained in:
Tamara Norman 2019-02-08 07:10:32 -08:00 committed by TensorFlower Gardener
parent f259cb4ac2
commit bd8eb07ad5
6 changed files with 316 additions and 54 deletions

View File

@ -1678,6 +1678,20 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "conv1d_transpose_test",
size = "small",
srcs = ["conv1d_transpose_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_ops",
],
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "conv2d_transpose_test",
size = "small",

View File

@ -68,7 +68,7 @@ class Conv1DTest(test.TestCase):
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, stride=stride, padding="VALID")
x, f, y_shape, strides=stride, padding="VALID")
value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32)

View File

@ -0,0 +1,260 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convolution related functionality in tensorflow.ops.nn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class Conv1DTransposeTest(test.TestCase):
def testConv1DTransposeSingleStride(self):
with self.cached_session():
strides = [1, 1, 1]
# Input, output: [batch, width, depth]
x_shape = [2, 6, 3]
y_shape = [2, 6, 2]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
for n in xrange(y_shape[0]):
for w in xrange(y_shape[1]):
for c in xrange(y_shape[2]):
target = 2 * 3.0
w_in = w > 0 and w < y_shape[1] - 1
if w_in:
target += 3.0
self.assertAllClose(target, value[n, w, c])
def testConv1DTransposeSame(self):
with self.cached_session():
strides = [1, 2, 1]
# Input, output: [batch, width, depth]
x_shape = [2, 4, 3]
y_shape = [2, 8, 2]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[1]):
for w in xrange(y_shape[1]):
target = 3.0
# We add a case for locations divisible by the stride.
w_in = w % strides[1] == 0 and w > 0 and w < y_shape[1] - 1
if w_in:
target += 3.0
self.assertAllClose(target, value[n, w, k])
def testConv1DTransposeValid(self):
with self.cached_session():
strides = [1, 2, 1]
# Input, output: [batch, width, depth]
x_shape = [2, 4, 3]
y_shape = [2, 9, 2]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="VALID")
value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32)
# The amount of padding added
pad = 1
for n in xrange(x_shape[0]):
for k in xrange(f_shape[1]):
for w in xrange(pad, y_shape[1] - pad):
target = 3.0
# We add a case for locations divisible by the stride.
w_in = w % strides[1] == 0 and w > pad and w < y_shape[1] - 1 - pad
if w_in:
target += 3.0
cache_values[n, w, k] = target
# copy values in the border
cache_values[n, 0, k] = cache_values[n, 1, k]
cache_values[n, -1, k] = cache_values[n, -2, k]
cache_values[n, :, k] = cache_values[n, :, k]
self.assertAllClose(cache_values, value)
@test_util.run_deprecated_v1
def testGradient(self):
x_shape = [2, 4, 3]
f_shape = [3, 2, 3]
y_shape = [2, 8, 2]
strides = [1, 2, 1]
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
with self.cached_session():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape],
output, y_shape)
print("conv1d_transpose gradient err = %g " % err)
err_tolerance = 0.0005
self.assertLess(err, err_tolerance)
def testConv1DTransposeSingleStrideNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
with self.session(use_gpu=True):
strides = [1, 1, 1]
# Input, output: [batch, depth, width]
x_shape = [2, 3, 4]
y_shape = [2, 2, 4]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="SAME", data_format="NCW")
value = self.evaluate(output)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[1]):
for w in xrange(y_shape[2]):
target = 2 * 3.0
w_in = w > 0 and w < y_shape[2] - 1
if w_in:
target += 3.0
self.assertAllClose(target, value[n, k, w])
def testConv1DTransposeSameNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
with self.session(use_gpu=True):
strides = [1, 1, 2]
# Input, output: [batch, depth, width]
x_shape = [2, 3, 4]
y_shape = [2, 2, 8]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="SAME", data_format="NCW")
value = self.evaluate(output)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[1]):
for w in xrange(y_shape[2]):
target = 3.0
# We add a case for locations divisible by the stride.
w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
if w_in:
target += 3.0
self.assertAllClose(target, value[n, k, w])
def testConv1DTransposeValidNCW(self):
# `NCW` data format is only supported for CUDA device.
if test.is_gpu_available(cuda_only=True):
with self.session(use_gpu=True):
strides = [1, 1, 2]
# Input, output: [batch, depth, width]
x_shape = [2, 3, 4]
y_shape = [2, 2, 9]
# Filter: [kernel_width, output_depth, input_depth]
f_shape = [3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv1d_transpose(
x, f, y_shape, strides=strides, padding="VALID", data_format="NCW")
value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32)
# The amount of padding added
pad = 1
for n in xrange(x_shape[0]):
for k in xrange(f_shape[1]):
for w in xrange(pad, y_shape[2] - pad):
target = 3.0
# We add a case for locations divisible by the stride.
w_in = w % strides[2] == 0 and w > pad and \
w < y_shape[2] - 1 - pad
if w_in:
target += 3.0
cache_values[n, k, w] = target
# copy values in the border
cache_values[n, k, 0] = cache_values[n, k, 1]
cache_values[n, k, -1] = cache_values[n, k, -2]
cache_values[n, k, :] = cache_values[n, k, :]
self.assertAllClose(cache_values, value)
if __name__ == "__main__":
test.main()

View File

@ -4202,13 +4202,15 @@ def conv1d_v2(
dilations=dilations)
@tf_export("nn.conv1d_transpose")
def conv1d_transpose(
value,
filter, # pylint: disable=redefined-builtin
input, # pylint: disable=redefined-builtin
filters,
output_shape,
stride,
strides,
padding="SAME",
data_format="NWC",
dilations=None,
name=None):
"""The transpose of `conv1d`.
@ -4218,19 +4220,23 @@ def conv1d_transpose(
deconvolution.
Args:
value: A 3-D `Tensor` of type `float` and shape
input: A 3-D `Tensor` of type `float` and shape
`[batch, in_width, in_channels]` for `NWC` data format or
`[batch, in_channels, in_width]` for `NCW` data format.
filter: A 3-D `Tensor` with the same type as `value` and shape
filters: A 3-D `Tensor` with the same type as `value` and shape
`[filter_width, output_channels, in_channels]`. `filter`'s
`in_channels` dimension must match that of `value`.
output_shape: A 1-D `Tensor`, containing three elements, representing the
output shape of the deconvolution op.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
strides: An int or list of `ints` that has length `1` or `3`. The number of
entries by which the filter is moved right at each step.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
data_format: A string. `'NWC'` and `'NCW'` are supported.
dilations: An int or list of `ints` that has length `1` or `3` which
defaults to 1. The dilation factor for each dimension of input. If set to
k > 1, there will be k-1 skipped cells between each filter element on that
dimension. Dilations in the batch and depth dimensions must be 1.
name: Optional name for the returned tensor.
Returns:
@ -4242,64 +4248,38 @@ def conv1d_transpose(
`'VALID'` or `'SAME'`, or if `data_format` is invalid.
"""
with ops.name_scope(name, "conv1d_transpose",
[value, filter, output_shape]) as name:
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(3)):
raise ValueError("output_shape must have shape (3,), got {}".format(
output_shape_.get_shape()))
[input, filters, output_shape]) as name:
# The format could be either NWC or NCW, map to NHWC or NCHW
if data_format is None or data_format == "NWC":
data_format_2d = "NHWC"
axis = 2
data_format = "NHWC"
spatial_start_dim = 1
channel_index = 2
elif data_format == "NCW":
data_format_2d = "NCHW"
axis = 1
data_format = "NCHW"
spatial_start_dim = 2
channel_index = 1
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
if not value.get_shape().dims[axis].is_compatible_with(
filter.get_shape()[2]):
raise ValueError("input channels does not match filter's input channels, "
"{} != {}".format(value.get_shape()[axis],
filter.get_shape()[2]))
if isinstance(output_shape, (list, np.ndarray)):
# output_shape's shape should be == [3] if reached this point.
if not filter.get_shape().dims[1].is_compatible_with(
output_shape[axis]):
raise ValueError(
"output_shape does not match filter's output channels, "
"{} != {}".format(output_shape[axis],
filter.get_shape()[1]))
if padding != "VALID" and padding != "SAME":
raise ValueError("padding must be either VALID or SAME:"
" {}".format(padding))
# Reshape the input tensor to [batch, 1, in_width, in_channels]
if data_format_2d == "NHWC":
output_shape_ = array_ops.concat(
[output_shape_[:1], [1], output_shape_[1:]], axis=0)
spatial_start_dim = 1
strides = [1, 1, stride, 1]
else:
output_shape_ = array_ops.concat(
[output_shape_[:2], [1], output_shape_[2:]], axis=0)
spatial_start_dim = 2
strides = [1, 1, 1, stride]
value = array_ops.expand_dims(value, spatial_start_dim)
filter = array_ops.expand_dims(filter, 0) # pylint: disable=redefined-builtin
strides = [1] + _get_sequence(strides, 1, channel_index, "stride")
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
input = array_ops.expand_dims(input, spatial_start_dim)
filters = array_ops.expand_dims(filters, 0) # pylint: disable=redefined-builtin
output_shape = output_shape[: spatial_start_dim] + [1] + \
output_shape[spatial_start_dim:]
result = gen_nn_ops.conv2d_backprop_input(
input_sizes=output_shape_,
filter=filter,
out_backprop=value,
input_sizes=output_shape,
filter=filters,
out_backprop=input,
strides=strides,
padding=padding,
data_format=data_format_2d,
data_format=data_format,
dilations=dilations,
name=name)
return array_ops.squeeze(result, [spatial_start_dim])
return array_ops.squeeze(result, spatial_start_dim)
@ops.RegisterStatistics("Dilation2D", "flops")

View File

@ -56,6 +56,10 @@ tf_module {
name: "conv1d"
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "conv1d_transpose"
argspec: "args=[\'input\', \'filters\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NWC\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\', \'filters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\', \'None\'], "

View File

@ -52,6 +52,10 @@ tf_module {
name: "conv1d"
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
}
member_method {
name: "conv1d_transpose"
argspec: "args=[\'input\', \'filters\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NWC\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"
argspec: "args=[\'input\', \'filters\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\', \'None\'], "