Method for conv_1d transpose updated to use dilations and exported
PiperOrigin-RevId: 233054694
This commit is contained in:
parent
f259cb4ac2
commit
bd8eb07ad5
@ -1678,6 +1678,20 @@ cuda_py_test(
|
|||||||
xla_enable_strict_auto_jit = True,
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "conv1d_transpose_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["conv1d_transpose_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:client",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:nn_ops",
|
||||||
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "conv2d_transpose_test",
|
name = "conv2d_transpose_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -68,7 +68,7 @@ class Conv1DTest(test.TestCase):
|
|||||||
f = constant_op.constant(
|
f = constant_op.constant(
|
||||||
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
output = nn_ops.conv1d_transpose(
|
output = nn_ops.conv1d_transpose(
|
||||||
x, f, y_shape, stride=stride, padding="VALID")
|
x, f, y_shape, strides=stride, padding="VALID")
|
||||||
value = self.evaluate(output)
|
value = self.evaluate(output)
|
||||||
|
|
||||||
cache_values = np.zeros(y_shape, dtype=np.float32)
|
cache_values = np.zeros(y_shape, dtype=np.float32)
|
||||||
|
260
tensorflow/python/kernel_tests/conv1d_transpose_test.py
Normal file
260
tensorflow/python/kernel_tests/conv1d_transpose_test.py
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for convolution related functionality in tensorflow.ops.nn."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import gradient_checker
|
||||||
|
from tensorflow.python.ops import nn_ops
|
||||||
|
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1DTransposeTest(test.TestCase):
|
||||||
|
|
||||||
|
def testConv1DTransposeSingleStride(self):
|
||||||
|
with self.cached_session():
|
||||||
|
strides = [1, 1, 1]
|
||||||
|
|
||||||
|
# Input, output: [batch, width, depth]
|
||||||
|
x_shape = [2, 6, 3]
|
||||||
|
y_shape = [2, 6, 2]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="SAME")
|
||||||
|
value = self.evaluate(output)
|
||||||
|
|
||||||
|
for n in xrange(y_shape[0]):
|
||||||
|
for w in xrange(y_shape[1]):
|
||||||
|
for c in xrange(y_shape[2]):
|
||||||
|
target = 2 * 3.0
|
||||||
|
w_in = w > 0 and w < y_shape[1] - 1
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
self.assertAllClose(target, value[n, w, c])
|
||||||
|
|
||||||
|
def testConv1DTransposeSame(self):
|
||||||
|
with self.cached_session():
|
||||||
|
strides = [1, 2, 1]
|
||||||
|
|
||||||
|
# Input, output: [batch, width, depth]
|
||||||
|
x_shape = [2, 4, 3]
|
||||||
|
y_shape = [2, 8, 2]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="SAME")
|
||||||
|
value = self.evaluate(output)
|
||||||
|
|
||||||
|
for n in xrange(x_shape[0]):
|
||||||
|
for k in xrange(f_shape[1]):
|
||||||
|
for w in xrange(y_shape[1]):
|
||||||
|
target = 3.0
|
||||||
|
# We add a case for locations divisible by the stride.
|
||||||
|
w_in = w % strides[1] == 0 and w > 0 and w < y_shape[1] - 1
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
self.assertAllClose(target, value[n, w, k])
|
||||||
|
|
||||||
|
def testConv1DTransposeValid(self):
|
||||||
|
with self.cached_session():
|
||||||
|
strides = [1, 2, 1]
|
||||||
|
|
||||||
|
# Input, output: [batch, width, depth]
|
||||||
|
x_shape = [2, 4, 3]
|
||||||
|
y_shape = [2, 9, 2]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="VALID")
|
||||||
|
value = self.evaluate(output)
|
||||||
|
|
||||||
|
cache_values = np.zeros(y_shape, dtype=np.float32)
|
||||||
|
|
||||||
|
# The amount of padding added
|
||||||
|
pad = 1
|
||||||
|
|
||||||
|
for n in xrange(x_shape[0]):
|
||||||
|
for k in xrange(f_shape[1]):
|
||||||
|
for w in xrange(pad, y_shape[1] - pad):
|
||||||
|
target = 3.0
|
||||||
|
# We add a case for locations divisible by the stride.
|
||||||
|
w_in = w % strides[1] == 0 and w > pad and w < y_shape[1] - 1 - pad
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
cache_values[n, w, k] = target
|
||||||
|
|
||||||
|
# copy values in the border
|
||||||
|
cache_values[n, 0, k] = cache_values[n, 1, k]
|
||||||
|
cache_values[n, -1, k] = cache_values[n, -2, k]
|
||||||
|
cache_values[n, :, k] = cache_values[n, :, k]
|
||||||
|
|
||||||
|
self.assertAllClose(cache_values, value)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradient(self):
|
||||||
|
x_shape = [2, 4, 3]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
y_shape = [2, 8, 2]
|
||||||
|
strides = [1, 2, 1]
|
||||||
|
np.random.seed(1) # Make it reproducible.
|
||||||
|
x_val = np.random.random_sample(x_shape).astype(np.float64)
|
||||||
|
f_val = np.random.random_sample(f_shape).astype(np.float64)
|
||||||
|
with self.cached_session():
|
||||||
|
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="SAME")
|
||||||
|
err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape],
|
||||||
|
output, y_shape)
|
||||||
|
print("conv1d_transpose gradient err = %g " % err)
|
||||||
|
err_tolerance = 0.0005
|
||||||
|
self.assertLess(err, err_tolerance)
|
||||||
|
|
||||||
|
def testConv1DTransposeSingleStrideNCW(self):
|
||||||
|
# `NCW` data format is only supported for CUDA device.
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
strides = [1, 1, 1]
|
||||||
|
|
||||||
|
# Input, output: [batch, depth, width]
|
||||||
|
x_shape = [2, 3, 4]
|
||||||
|
y_shape = [2, 2, 4]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="SAME", data_format="NCW")
|
||||||
|
|
||||||
|
value = self.evaluate(output)
|
||||||
|
for n in xrange(x_shape[0]):
|
||||||
|
for k in xrange(f_shape[1]):
|
||||||
|
for w in xrange(y_shape[2]):
|
||||||
|
target = 2 * 3.0
|
||||||
|
w_in = w > 0 and w < y_shape[2] - 1
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
self.assertAllClose(target, value[n, k, w])
|
||||||
|
|
||||||
|
def testConv1DTransposeSameNCW(self):
|
||||||
|
# `NCW` data format is only supported for CUDA device.
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
strides = [1, 1, 2]
|
||||||
|
|
||||||
|
# Input, output: [batch, depth, width]
|
||||||
|
x_shape = [2, 3, 4]
|
||||||
|
y_shape = [2, 2, 8]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="SAME", data_format="NCW")
|
||||||
|
|
||||||
|
value = self.evaluate(output)
|
||||||
|
for n in xrange(x_shape[0]):
|
||||||
|
for k in xrange(f_shape[1]):
|
||||||
|
for w in xrange(y_shape[2]):
|
||||||
|
target = 3.0
|
||||||
|
# We add a case for locations divisible by the stride.
|
||||||
|
w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
self.assertAllClose(target, value[n, k, w])
|
||||||
|
|
||||||
|
def testConv1DTransposeValidNCW(self):
|
||||||
|
# `NCW` data format is only supported for CUDA device.
|
||||||
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
strides = [1, 1, 2]
|
||||||
|
|
||||||
|
# Input, output: [batch, depth, width]
|
||||||
|
x_shape = [2, 3, 4]
|
||||||
|
y_shape = [2, 2, 9]
|
||||||
|
|
||||||
|
# Filter: [kernel_width, output_depth, input_depth]
|
||||||
|
f_shape = [3, 2, 3]
|
||||||
|
|
||||||
|
x = constant_op.constant(
|
||||||
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
|
f = constant_op.constant(
|
||||||
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
|
output = nn_ops.conv1d_transpose(
|
||||||
|
x, f, y_shape, strides=strides, padding="VALID", data_format="NCW")
|
||||||
|
|
||||||
|
value = self.evaluate(output)
|
||||||
|
cache_values = np.zeros(y_shape, dtype=np.float32)
|
||||||
|
# The amount of padding added
|
||||||
|
pad = 1
|
||||||
|
for n in xrange(x_shape[0]):
|
||||||
|
for k in xrange(f_shape[1]):
|
||||||
|
for w in xrange(pad, y_shape[2] - pad):
|
||||||
|
target = 3.0
|
||||||
|
# We add a case for locations divisible by the stride.
|
||||||
|
w_in = w % strides[2] == 0 and w > pad and \
|
||||||
|
w < y_shape[2] - 1 - pad
|
||||||
|
if w_in:
|
||||||
|
target += 3.0
|
||||||
|
cache_values[n, k, w] = target
|
||||||
|
|
||||||
|
# copy values in the border
|
||||||
|
cache_values[n, k, 0] = cache_values[n, k, 1]
|
||||||
|
cache_values[n, k, -1] = cache_values[n, k, -2]
|
||||||
|
cache_values[n, k, :] = cache_values[n, k, :]
|
||||||
|
|
||||||
|
self.assertAllClose(cache_values, value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -4202,13 +4202,15 @@ def conv1d_v2(
|
|||||||
dilations=dilations)
|
dilations=dilations)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("nn.conv1d_transpose")
|
||||||
def conv1d_transpose(
|
def conv1d_transpose(
|
||||||
value,
|
input, # pylint: disable=redefined-builtin
|
||||||
filter, # pylint: disable=redefined-builtin
|
filters,
|
||||||
output_shape,
|
output_shape,
|
||||||
stride,
|
strides,
|
||||||
padding="SAME",
|
padding="SAME",
|
||||||
data_format="NWC",
|
data_format="NWC",
|
||||||
|
dilations=None,
|
||||||
name=None):
|
name=None):
|
||||||
"""The transpose of `conv1d`.
|
"""The transpose of `conv1d`.
|
||||||
|
|
||||||
@ -4218,19 +4220,23 @@ def conv1d_transpose(
|
|||||||
deconvolution.
|
deconvolution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: A 3-D `Tensor` of type `float` and shape
|
input: A 3-D `Tensor` of type `float` and shape
|
||||||
`[batch, in_width, in_channels]` for `NWC` data format or
|
`[batch, in_width, in_channels]` for `NWC` data format or
|
||||||
`[batch, in_channels, in_width]` for `NCW` data format.
|
`[batch, in_channels, in_width]` for `NCW` data format.
|
||||||
filter: A 3-D `Tensor` with the same type as `value` and shape
|
filters: A 3-D `Tensor` with the same type as `value` and shape
|
||||||
`[filter_width, output_channels, in_channels]`. `filter`'s
|
`[filter_width, output_channels, in_channels]`. `filter`'s
|
||||||
`in_channels` dimension must match that of `value`.
|
`in_channels` dimension must match that of `value`.
|
||||||
output_shape: A 1-D `Tensor`, containing three elements, representing the
|
output_shape: A 1-D `Tensor`, containing three elements, representing the
|
||||||
output shape of the deconvolution op.
|
output shape of the deconvolution op.
|
||||||
stride: An `integer`. The number of entries by which
|
strides: An int or list of `ints` that has length `1` or `3`. The number of
|
||||||
the filter is moved right at each step.
|
entries by which the filter is moved right at each step.
|
||||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
|
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
|
||||||
See the "returns" section of `tf.nn.convolution` for details.
|
See the "returns" section of `tf.nn.convolution` for details.
|
||||||
data_format: A string. `'NWC'` and `'NCW'` are supported.
|
data_format: A string. `'NWC'` and `'NCW'` are supported.
|
||||||
|
dilations: An int or list of `ints` that has length `1` or `3` which
|
||||||
|
defaults to 1. The dilation factor for each dimension of input. If set to
|
||||||
|
k > 1, there will be k-1 skipped cells between each filter element on that
|
||||||
|
dimension. Dilations in the batch and depth dimensions must be 1.
|
||||||
name: Optional name for the returned tensor.
|
name: Optional name for the returned tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -4242,64 +4248,38 @@ def conv1d_transpose(
|
|||||||
`'VALID'` or `'SAME'`, or if `data_format` is invalid.
|
`'VALID'` or `'SAME'`, or if `data_format` is invalid.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "conv1d_transpose",
|
with ops.name_scope(name, "conv1d_transpose",
|
||||||
[value, filter, output_shape]) as name:
|
[input, filters, output_shape]) as name:
|
||||||
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
|
|
||||||
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(3)):
|
|
||||||
raise ValueError("output_shape must have shape (3,), got {}".format(
|
|
||||||
output_shape_.get_shape()))
|
|
||||||
|
|
||||||
# The format could be either NWC or NCW, map to NHWC or NCHW
|
# The format could be either NWC or NCW, map to NHWC or NCHW
|
||||||
if data_format is None or data_format == "NWC":
|
if data_format is None or data_format == "NWC":
|
||||||
data_format_2d = "NHWC"
|
data_format = "NHWC"
|
||||||
axis = 2
|
spatial_start_dim = 1
|
||||||
|
channel_index = 2
|
||||||
elif data_format == "NCW":
|
elif data_format == "NCW":
|
||||||
data_format_2d = "NCHW"
|
data_format = "NCHW"
|
||||||
axis = 1
|
spatial_start_dim = 2
|
||||||
|
channel_index = 1
|
||||||
else:
|
else:
|
||||||
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
||||||
|
|
||||||
if not value.get_shape().dims[axis].is_compatible_with(
|
|
||||||
filter.get_shape()[2]):
|
|
||||||
raise ValueError("input channels does not match filter's input channels, "
|
|
||||||
"{} != {}".format(value.get_shape()[axis],
|
|
||||||
filter.get_shape()[2]))
|
|
||||||
|
|
||||||
if isinstance(output_shape, (list, np.ndarray)):
|
|
||||||
# output_shape's shape should be == [3] if reached this point.
|
|
||||||
if not filter.get_shape().dims[1].is_compatible_with(
|
|
||||||
output_shape[axis]):
|
|
||||||
raise ValueError(
|
|
||||||
"output_shape does not match filter's output channels, "
|
|
||||||
"{} != {}".format(output_shape[axis],
|
|
||||||
filter.get_shape()[1]))
|
|
||||||
|
|
||||||
if padding != "VALID" and padding != "SAME":
|
|
||||||
raise ValueError("padding must be either VALID or SAME:"
|
|
||||||
" {}".format(padding))
|
|
||||||
|
|
||||||
# Reshape the input tensor to [batch, 1, in_width, in_channels]
|
# Reshape the input tensor to [batch, 1, in_width, in_channels]
|
||||||
if data_format_2d == "NHWC":
|
strides = [1] + _get_sequence(strides, 1, channel_index, "stride")
|
||||||
output_shape_ = array_ops.concat(
|
dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
|
||||||
[output_shape_[:1], [1], output_shape_[1:]], axis=0)
|
|
||||||
spatial_start_dim = 1
|
input = array_ops.expand_dims(input, spatial_start_dim)
|
||||||
strides = [1, 1, stride, 1]
|
filters = array_ops.expand_dims(filters, 0) # pylint: disable=redefined-builtin
|
||||||
else:
|
output_shape = output_shape[: spatial_start_dim] + [1] + \
|
||||||
output_shape_ = array_ops.concat(
|
output_shape[spatial_start_dim:]
|
||||||
[output_shape_[:2], [1], output_shape_[2:]], axis=0)
|
|
||||||
spatial_start_dim = 2
|
|
||||||
strides = [1, 1, 1, stride]
|
|
||||||
value = array_ops.expand_dims(value, spatial_start_dim)
|
|
||||||
filter = array_ops.expand_dims(filter, 0) # pylint: disable=redefined-builtin
|
|
||||||
|
|
||||||
result = gen_nn_ops.conv2d_backprop_input(
|
result = gen_nn_ops.conv2d_backprop_input(
|
||||||
input_sizes=output_shape_,
|
input_sizes=output_shape,
|
||||||
filter=filter,
|
filter=filters,
|
||||||
out_backprop=value,
|
out_backprop=input,
|
||||||
strides=strides,
|
strides=strides,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
data_format=data_format_2d,
|
data_format=data_format,
|
||||||
|
dilations=dilations,
|
||||||
name=name)
|
name=name)
|
||||||
return array_ops.squeeze(result, [spatial_start_dim])
|
return array_ops.squeeze(result, spatial_start_dim)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterStatistics("Dilation2D", "flops")
|
@ops.RegisterStatistics("Dilation2D", "flops")
|
||||||
|
@ -56,6 +56,10 @@ tf_module {
|
|||||||
name: "conv1d"
|
name: "conv1d"
|
||||||
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "conv1d_transpose"
|
||||||
|
argspec: "args=[\'input\', \'filters\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NWC\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv2d"
|
name: "conv2d"
|
||||||
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\', \'filters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\', \'None\'], "
|
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\', \'filters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\', \'None\'], "
|
||||||
|
@ -52,6 +52,10 @@ tf_module {
|
|||||||
name: "conv1d"
|
name: "conv1d"
|
||||||
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
|
argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "conv1d_transpose"
|
||||||
|
argspec: "args=[\'input\', \'filters\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NWC\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "conv2d"
|
name: "conv2d"
|
||||||
argspec: "args=[\'input\', \'filters\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\', \'None\'], "
|
argspec: "args=[\'input\', \'filters\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\', \'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user