Add data_format support for N-D convolution and pooling.
This adds support for "NC*" data layouts for N-D convolution and pooling (including atrous convolution and pooling); previously, only "N*C" data layouts were supported. This also adds support for 1-D pooling (by forwarding to the 2-D implementation), and fixes the broken data_format support in conv1d. Change: 136556507
This commit is contained in:
parent
45010d6a49
commit
8fa9b949dc
@ -48,15 +48,15 @@ def upsample_filters(filters, rate):
|
|||||||
|
|
||||||
class AtrousConvolutionTest(tf.test.TestCase):
|
class AtrousConvolutionTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _test_atrous_convolution(self, input_shape, filter_shape, padding,
|
def _test_atrous_convolution(self, input_shape, filter_shape, dilation_rate,
|
||||||
dilation_rate):
|
**kwargs):
|
||||||
filters = np.arange(
|
filters = np.arange(
|
||||||
np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
|
np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
|
||||||
filters_upsampled = upsample_filters(filters, dilation_rate)
|
filters_upsampled = upsample_filters(filters, dilation_rate)
|
||||||
x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
|
x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
|
||||||
y1 = tf.nn.convolution(
|
y1 = tf.nn.convolution(
|
||||||
input=x, filter=filters, padding=padding, dilation_rate=dilation_rate)
|
input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
|
||||||
y2 = tf.nn.convolution(input=x, filter=filters_upsampled, padding=padding)
|
y2 = tf.nn.convolution(input=x, filter=filters_upsampled, **kwargs)
|
||||||
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)
|
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
def testAtrousConvolution2D(self):
|
def testAtrousConvolution2D(self):
|
||||||
@ -99,6 +99,24 @@ class AtrousConvolutionTest(tf.test.TestCase):
|
|||||||
padding=padding,
|
padding=padding,
|
||||||
dilation_rate=[rate])
|
dilation_rate=[rate])
|
||||||
|
|
||||||
|
def testAtrousConvolutionNC(self):
|
||||||
|
if tf.test.is_gpu_available():
|
||||||
|
# "NCW" and "NCHW" formats are not currently supported on CPU.
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
for padding in ["SAME", "VALID"]:
|
||||||
|
self._test_atrous_convolution(
|
||||||
|
input_shape=[2, 2, 9],
|
||||||
|
padding=padding,
|
||||||
|
filter_shape=[3, 2, 2],
|
||||||
|
dilation_rate=[2],
|
||||||
|
data_format="NCW")
|
||||||
|
self._test_atrous_convolution(
|
||||||
|
input_shape=[2, 2, 9, 5],
|
||||||
|
padding=padding,
|
||||||
|
filter_shape=[3, 3, 2, 2],
|
||||||
|
dilation_rate=[2, 1],
|
||||||
|
data_format="NCHW")
|
||||||
|
|
||||||
def testAtrousSequence(self):
|
def testAtrousSequence(self):
|
||||||
"""Tests optimization of sequence of atrous convolutions.
|
"""Tests optimization of sequence of atrous convolutions.
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def pool_direct_single_axis(input, # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
|
|
||||||
def pool_direct(input, window_shape, pooling_type, padding, # pylint: disable=redefined-builtin
|
def pool_direct(input, window_shape, pooling_type, padding, # pylint: disable=redefined-builtin
|
||||||
dilation_rate, strides):
|
dilation_rate, strides, data_format=None):
|
||||||
"""Numpy implementation of pooling.
|
"""Numpy implementation of pooling.
|
||||||
|
|
||||||
This is intended for testing only, and therefore isn't particularly efficient.
|
This is intended for testing only, and therefore isn't particularly efficient.
|
||||||
@ -103,6 +103,8 @@ def pool_direct(input, window_shape, pooling_type, padding, # pylint: disable=r
|
|||||||
padding: either "SAME" or "VALID".
|
padding: either "SAME" or "VALID".
|
||||||
dilation_rate: Sequence of N ints >= 1.
|
dilation_rate: Sequence of N ints >= 1.
|
||||||
strides: Sequence of N ints >= 1.
|
strides: Sequence of N ints >= 1.
|
||||||
|
data_format: If specified and starts with "NC", indicates that second
|
||||||
|
dimension, rather than the last dimension, specifies the channel.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pooling output array of rank N+2.
|
pooling output array of rank N+2.
|
||||||
@ -110,11 +112,15 @@ def pool_direct(input, window_shape, pooling_type, padding, # pylint: disable=r
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if arguments are invalid.
|
ValueError: if arguments are invalid.
|
||||||
"""
|
"""
|
||||||
|
if data_format is None or not data_format.startswith("NC"):
|
||||||
|
spatial_start_dim = 1
|
||||||
|
else:
|
||||||
|
spatial_start_dim = 2
|
||||||
output = input
|
output = input
|
||||||
for i in range(len(window_shape)):
|
for i in range(len(window_shape)):
|
||||||
output = pool_direct_single_axis(
|
output = pool_direct_single_axis(
|
||||||
input=output,
|
input=output,
|
||||||
axis=i + 1,
|
axis=i + spatial_start_dim,
|
||||||
window_size=window_shape[i],
|
window_size=window_shape[i],
|
||||||
pooling_type=pooling_type,
|
pooling_type=pooling_type,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
@ -125,26 +131,13 @@ def pool_direct(input, window_shape, pooling_type, padding, # pylint: disable=r
|
|||||||
|
|
||||||
class PoolingTest(tf.test.TestCase):
|
class PoolingTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _test(self, input_shape, window_shape, pooling_type, padding,
|
def _test(self, input_shape, **kwargs):
|
||||||
dilation_rate, strides):
|
|
||||||
# Use negative numbers to make sure there isn't any zero padding getting
|
# Use negative numbers to make sure there isn't any zero padding getting
|
||||||
# used.
|
# used.
|
||||||
x = -np.arange(
|
x = -np.arange(
|
||||||
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
|
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
|
||||||
y1 = pool_direct(
|
y1 = pool_direct(input=x, **kwargs)
|
||||||
input=x,
|
y2 = tf.nn.pool(input=x, **kwargs)
|
||||||
window_shape=window_shape,
|
|
||||||
pooling_type=pooling_type,
|
|
||||||
padding=padding,
|
|
||||||
dilation_rate=dilation_rate,
|
|
||||||
strides=strides)
|
|
||||||
y2 = tf.nn.pool(
|
|
||||||
input=x,
|
|
||||||
window_shape=window_shape,
|
|
||||||
pooling_type=pooling_type,
|
|
||||||
padding=padding,
|
|
||||||
dilation_rate=dilation_rate,
|
|
||||||
strides=strides)
|
|
||||||
self.assertAllClose(y1, y2.eval(), rtol=1e-2, atol=1e-2)
|
self.assertAllClose(y1, y2.eval(), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
def testPoolSimple(self):
|
def testPoolSimple(self):
|
||||||
@ -159,6 +152,32 @@ class PoolingTest(tf.test.TestCase):
|
|||||||
dilation_rate=[1, 1],
|
dilation_rate=[1, 1],
|
||||||
strides=[1, 2])
|
strides=[1, 2])
|
||||||
|
|
||||||
|
def testPool1D(self):
|
||||||
|
with self.test_session():
|
||||||
|
for padding in ["SAME", "VALID"]:
|
||||||
|
for pooling_type in ["MAX", "AVG"]:
|
||||||
|
for input_shape in [[2, 9, 2], [2, 10, 2]]:
|
||||||
|
for window_shape in [[1], [2], [3]]:
|
||||||
|
if padding != "SAME":
|
||||||
|
for dilation_rate in [[1], [2], [3]]:
|
||||||
|
self._test(
|
||||||
|
input_shape=input_shape,
|
||||||
|
window_shape=window_shape,
|
||||||
|
padding=padding,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
dilation_rate=dilation_rate,
|
||||||
|
strides=[1])
|
||||||
|
for strides in [[1], [2], [3]]:
|
||||||
|
if np.any(np.array(strides) > window_shape):
|
||||||
|
continue
|
||||||
|
self._test(
|
||||||
|
input_shape=input_shape,
|
||||||
|
window_shape=window_shape,
|
||||||
|
padding=padding,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
dilation_rate=[1],
|
||||||
|
strides=strides)
|
||||||
|
|
||||||
def testPool2D(self):
|
def testPool2D(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
for padding in ["SAME", "VALID"]:
|
for padding in ["SAME", "VALID"]:
|
||||||
@ -212,6 +231,40 @@ class PoolingTest(tf.test.TestCase):
|
|||||||
dilation_rate=[1, 1, 1],
|
dilation_rate=[1, 1, 1],
|
||||||
strides=strides)
|
strides=strides)
|
||||||
|
|
||||||
|
def testPoolNC(self):
|
||||||
|
if tf.test.is_gpu_available():
|
||||||
|
# "NC*" format is not currently supported on CPU.
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
for padding in ["SAME", "VALID"]:
|
||||||
|
self._test(input_shape=[2, 2, 9],
|
||||||
|
window_shape=[2],
|
||||||
|
padding=padding,
|
||||||
|
pooling_type="MAX",
|
||||||
|
strides=[1],
|
||||||
|
dilation_rate=[1],
|
||||||
|
data_format="NCW")
|
||||||
|
self._test(input_shape=[2, 2, 9],
|
||||||
|
window_shape=[2],
|
||||||
|
padding=padding,
|
||||||
|
pooling_type="MAX",
|
||||||
|
strides=[2],
|
||||||
|
dilation_rate=[1],
|
||||||
|
data_format="NCW")
|
||||||
|
self._test(input_shape=[2, 2, 7, 9],
|
||||||
|
window_shape=[2, 2],
|
||||||
|
padding=padding,
|
||||||
|
pooling_type="MAX",
|
||||||
|
strides=[1, 2],
|
||||||
|
dilation_rate=[1, 1],
|
||||||
|
data_format="NCHW")
|
||||||
|
self._test(input_shape=[2, 2, 7, 9],
|
||||||
|
window_shape=[2, 2],
|
||||||
|
padding="VALID",
|
||||||
|
pooling_type="MAX",
|
||||||
|
strides=[1, 1],
|
||||||
|
dilation_rate=[2, 2],
|
||||||
|
data_format="NCHW")
|
||||||
|
|
||||||
def _test_gradient(self, input_shape, **kwargs):
|
def _test_gradient(self, input_shape, **kwargs):
|
||||||
x_val = -np.arange(
|
x_val = -np.arange(
|
||||||
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
|
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
|
||||||
@ -224,6 +277,32 @@ class PoolingTest(tf.test.TestCase):
|
|||||||
err_tolerance = 1e-2
|
err_tolerance = 1e-2
|
||||||
self.assertLess(err, err_tolerance)
|
self.assertLess(err, err_tolerance)
|
||||||
|
|
||||||
|
def testGradient1D(self):
|
||||||
|
with self.test_session():
|
||||||
|
for padding in ["SAME", "VALID"]:
|
||||||
|
for pooling_type in ["AVG", "MAX"]:
|
||||||
|
for input_shape in [[2, 5, 2], [1, 4, 1]]:
|
||||||
|
for window_shape in [[1], [2]]:
|
||||||
|
if padding != "SAME":
|
||||||
|
for dilation_rate in [[1], [2]]:
|
||||||
|
self._test_gradient(
|
||||||
|
input_shape=input_shape,
|
||||||
|
window_shape=window_shape,
|
||||||
|
padding=padding,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
dilation_rate=dilation_rate,
|
||||||
|
strides=[1])
|
||||||
|
for strides in [[1], [2]]:
|
||||||
|
if np.any(np.array(strides) > window_shape):
|
||||||
|
continue
|
||||||
|
self._test(
|
||||||
|
input_shape=input_shape,
|
||||||
|
window_shape=window_shape,
|
||||||
|
padding=padding,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
dilation_rate=[1],
|
||||||
|
strides=strides)
|
||||||
|
|
||||||
def testGradient2D(self):
|
def testGradient2D(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
for padding in ["SAME", "VALID"]:
|
for padding in ["SAME", "VALID"]:
|
||||||
|
@ -42,7 +42,8 @@ from tensorflow.python.ops.gen_nn_ops import *
|
|||||||
local_response_normalization = gen_nn_ops.lrn
|
local_response_normalization = gen_nn_ops.lrn
|
||||||
|
|
||||||
|
|
||||||
def _non_atrous_convolution(input, filter, padding, strides=None, name=None): # pylint: disable=redefined-builtin
|
def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint: disable=redefined-builtin
|
||||||
|
strides=None, name=None):
|
||||||
"""Computes sums of N-D convolutions (actually cross correlation).
|
"""Computes sums of N-D convolutions (actually cross correlation).
|
||||||
|
|
||||||
It is required that 1 <= N <= 3.
|
It is required that 1 <= N <= 3.
|
||||||
@ -51,12 +52,22 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None): #
|
|||||||
extends the interface of this function with a `dilation_rate` parameter.
|
extends the interface of this function with a `dilation_rate` parameter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
||||||
input: Rank N+2 tensor of type T of shape
|
input: Rank N+2 tensor of type T of shape
|
||||||
`[batch_size] + input_spatial_shape + [in_channels]`.
|
`[batch_size] + input_spatial_shape + [in_channels]` if `data_format`
|
||||||
|
does not start with `"NC"`, or
|
||||||
|
`[batch_size, in_channels] + input_spatial_shape` if `data_format` starts
|
||||||
|
with `"NC"`.
|
||||||
filter: Rank N+2 tensor of type T of shape
|
filter: Rank N+2 tensor of type T of shape
|
||||||
`filter_spatial_shape + [in_channels, out_channels]`. Rank of either
|
`filter_spatial_shape + [in_channels, out_channels]`. Rank of either
|
||||||
`input` or `filter` must be known.
|
`input` or `filter` must be known.
|
||||||
padding: Padding method to use, must be either "VALID" or "SAME".
|
padding: Padding method to use, must be either "VALID" or "SAME".
|
||||||
|
data_format: A string or None. Specifies whether the channel dimension of
|
||||||
|
the `input` and output is the last dimension (default, or if `data_format`
|
||||||
|
does not start with "NC"), or the second dimension (if `data_format`
|
||||||
|
starts with "NC"). For N=1, the valid values are "NWC" (default) and
|
||||||
|
"NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
|
||||||
|
N=3, the valid value is "NDHWC".
|
||||||
strides: Sequence of N positive integers, defaults to `[1] * N`.
|
strides: Sequence of N positive integers, defaults to `[1] * N`.
|
||||||
name: Name prefix to use.
|
name: Name prefix to use.
|
||||||
|
|
||||||
@ -89,26 +100,50 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None): #
|
|||||||
raise ValueError("len(strides)=%d, but should be %d" %
|
raise ValueError("len(strides)=%d, but should be %d" %
|
||||||
(len(strides), conv_dims))
|
(len(strides), conv_dims))
|
||||||
if conv_dims == 1:
|
if conv_dims == 1:
|
||||||
return conv1d(value=input,
|
# conv1d uses the 2-d data format names
|
||||||
filters=filter,
|
if data_format is None or data_format == "NWC":
|
||||||
stride=strides[0],
|
data_format_2d = "NHWC"
|
||||||
padding=padding,
|
elif data_format == "NCW":
|
||||||
name=scope)
|
data_format_2d = "NCHW"
|
||||||
|
else:
|
||||||
|
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
|
||||||
|
return conv1d(
|
||||||
|
value=input,
|
||||||
|
filters=filter,
|
||||||
|
stride=strides[0],
|
||||||
|
padding=padding,
|
||||||
|
data_format=data_format_2d,
|
||||||
|
name=scope)
|
||||||
elif conv_dims == 2:
|
elif conv_dims == 2:
|
||||||
return gen_nn_ops.conv2d(input=input,
|
if data_format is None or data_format == "NHWC":
|
||||||
filter=filter,
|
data_format = "NHWC"
|
||||||
strides=[1] + list(strides) + [1],
|
strides = [1] + list(strides) + [1]
|
||||||
padding=padding,
|
elif data_format == "NCHW":
|
||||||
name=name)
|
strides = [1, 1] + list(strides)
|
||||||
|
else:
|
||||||
|
raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
|
||||||
|
return gen_nn_ops.conv2d(
|
||||||
|
input=input,
|
||||||
|
filter=filter,
|
||||||
|
strides=strides,
|
||||||
|
padding=padding,
|
||||||
|
data_format=data_format,
|
||||||
|
name=name)
|
||||||
elif conv_dims == 3:
|
elif conv_dims == 3:
|
||||||
return gen_nn_ops.conv3d(input=input,
|
if data_format is None or data_format == "NDHWC":
|
||||||
filter=filter,
|
strides = [1] + list(strides) + [1]
|
||||||
strides=[1] + list(strides) + [1],
|
else:
|
||||||
padding=padding,
|
raise ValueError("data_format must be \"NDHWC\".")
|
||||||
name=name)
|
return gen_nn_ops.conv3d(
|
||||||
|
input=input,
|
||||||
|
filter=filter,
|
||||||
|
strides=strides,
|
||||||
|
padding=padding,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None): # pylint: disable=redefined-builtin
|
def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None, # pylint: disable=redefined-builtin
|
||||||
|
spatial_dims=None):
|
||||||
"""Performs `op` on the space-to-batch representation of `input`.
|
"""Performs `op` on the space-to-batch representation of `input`.
|
||||||
|
|
||||||
This has the effect of transforming sliding window operations into the
|
This has the effect of transforming sliding window operations into the
|
||||||
@ -122,19 +157,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
|
|||||||
Otherwise, it returns:
|
Otherwise, it returns:
|
||||||
|
|
||||||
batch_to_space_nd(
|
batch_to_space_nd(
|
||||||
op(space_to_batch_nd(input, dilation_rate, paddings),
|
op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
|
||||||
num_spatial_dims,
|
num_spatial_dims,
|
||||||
"VALID")
|
"VALID")
|
||||||
dilation_rate,
|
adjusted_dilation_rate,
|
||||||
crops),
|
adjusted_crops),
|
||||||
|
|
||||||
where `paddings` and `crops` are int32 [num_spatial_dims, 2] tensors that
|
where:
|
||||||
depend on the value of `padding`:
|
|
||||||
|
adjusted_dilation_rate is an int64 tensor of shape [max(spatial_dims)],
|
||||||
|
adjusted_{paddings,crops} are int64 tensors of shape [max(spatial_dims), 2]
|
||||||
|
|
||||||
|
defined as follows:
|
||||||
|
|
||||||
|
We first define two int64 tensors `paddings` and `crops` of shape
|
||||||
|
`[num_spatial_dims, 2]` based on the value of `padding` and the spatial
|
||||||
|
dimensions of the `input`:
|
||||||
|
|
||||||
If `padding = "VALID"`, then:
|
If `padding = "VALID"`, then:
|
||||||
|
|
||||||
paddings, crops = required_space_to_batch_paddings(
|
paddings, crops = required_space_to_batch_paddings(
|
||||||
input_shape[1:num_spatial_dims+1],
|
input_shape[spatial_dims],
|
||||||
dilation_rate)
|
dilation_rate)
|
||||||
|
|
||||||
If `padding = "SAME"`, then:
|
If `padding = "SAME"`, then:
|
||||||
@ -143,10 +186,30 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
|
|||||||
filter_shape + (filter_shape - 1) * (dilation_rate - 1)
|
filter_shape + (filter_shape - 1) * (dilation_rate - 1)
|
||||||
|
|
||||||
paddings, crops = required_space_to_batch_paddings(
|
paddings, crops = required_space_to_batch_paddings(
|
||||||
input_shape[1:num_spatial_dims+1],
|
input_shape[spatial_dims],
|
||||||
|
dilation_rate,
|
||||||
[(dilated_filter_shape - 1) // 2,
|
[(dilated_filter_shape - 1) // 2,
|
||||||
dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])
|
dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])
|
||||||
|
|
||||||
|
Because `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
|
||||||
|
dimensions are contiguous starting at the second dimension, but the specified
|
||||||
|
`spatial_dims` may not be, we must adjust `dilation_rate`, `paddings` and
|
||||||
|
`crops` in order to be usable with these operations. For a given dimension,
|
||||||
|
if the block size is 1, and both the starting and ending padding and crop
|
||||||
|
amounts are 0, then space_to_batch_nd effectively leaves that dimension alone,
|
||||||
|
which is what is needed for dimensions not part of `spatial_dims`.
|
||||||
|
Furthermore, `space_to_batch_nd` and `batch_to_space_nd` handle this case
|
||||||
|
efficiently for any number of leading and trailing dimensions.
|
||||||
|
|
||||||
|
For 0 <= i < len(spatial_dims), we assign:
|
||||||
|
|
||||||
|
adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
|
||||||
|
adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
|
||||||
|
adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]
|
||||||
|
|
||||||
|
All unassigned values of `adjusted_dilation_rate` default to 1, while all
|
||||||
|
unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.
|
||||||
|
|
||||||
Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
|
Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
|
||||||
padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
|
padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
|
||||||
`[1]*N`.
|
`[1]*N`.
|
||||||
@ -189,19 +252,23 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
|
|||||||
net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
|
net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input: Tensor of rank >= 1 + num_spatial_dims.
|
input: Tensor of rank > max(spatial_dims).
|
||||||
dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
|
dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
|
||||||
padding: str constant equal to "VALID" or "SAME"
|
padding: str constant equal to "VALID" or "SAME"
|
||||||
op: Function that maps (input, num_spatial_dims, padding) -> output
|
op: Function that maps (input, num_spatial_dims, padding) -> output
|
||||||
filter_shape: If padding = "SAME", specifies the shape of the convolution
|
filter_shape: If padding = "SAME", specifies the shape of the convolution
|
||||||
kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
|
kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
|
||||||
If padding = "VALID", filter_shape is ignored and need not be specified.
|
If padding = "VALID", filter_shape is ignored and need not be specified.
|
||||||
|
spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
|
||||||
|
integers (which are >= 1) specifying the spatial dimensions of `input`
|
||||||
|
and output. Defaults to: `range(1, num_spatial_dims+1)`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output Tensor as described above.
|
The output Tensor as described above.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if padding is invalid or the arguments are incompatible.
|
ValueError: if `padding` is invalid or the arguments are incompatible.
|
||||||
|
ValueError: if `spatial_dims` are invalid.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
input = ops.convert_to_tensor(input, name="input")
|
input = ops.convert_to_tensor(input, name="input")
|
||||||
@ -218,18 +285,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
|
|||||||
|
|
||||||
num_spatial_dims = rate_shape[0].value
|
num_spatial_dims = rate_shape[0].value
|
||||||
|
|
||||||
|
if spatial_dims is None:
|
||||||
|
spatial_dims = range(1, num_spatial_dims + 1)
|
||||||
|
orig_spatial_dims = list(spatial_dims)
|
||||||
|
spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
|
||||||
|
if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
|
||||||
|
raise ValueError(
|
||||||
|
"spatial_dims must be a montonically increasing sequence of positive integers") # pylint: disable=line-too-long
|
||||||
|
last_spatial_dim = spatial_dims[-1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
input.get_shape().with_rank_at_least(num_spatial_dims + 1)
|
input.get_shape().with_rank_at_least(last_spatial_dim + 1)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
ValueError("input tensor must have rank %d at least" %
|
ValueError("input tensor must have rank %d at least" %
|
||||||
(num_spatial_dims + 1))
|
(last_spatial_dim + 1))
|
||||||
|
|
||||||
const_rate = tensor_util.constant_value(dilation_rate)
|
const_rate = tensor_util.constant_value(dilation_rate)
|
||||||
rate_or_const_rate = dilation_rate
|
rate_or_const_rate = dilation_rate
|
||||||
if const_rate is not None:
|
if const_rate is not None:
|
||||||
rate_or_const_rate = const_rate
|
rate_or_const_rate = const_rate
|
||||||
if np.any(const_rate < 1):
|
if np.any(const_rate < 1):
|
||||||
raise ValueError("rate must be positive")
|
raise ValueError("dilation_rate must be positive")
|
||||||
if np.all(const_rate == 1):
|
if np.all(const_rate == 1):
|
||||||
return op(input, num_spatial_dims, padding)
|
return op(input, num_spatial_dims, padding)
|
||||||
|
|
||||||
@ -266,26 +342,88 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
|
|||||||
raise ValueError("Invalid padding method %r" % padding)
|
raise ValueError("Invalid padding method %r" % padding)
|
||||||
|
|
||||||
# Handle input whose shape is unknown during graph creation.
|
# Handle input whose shape is unknown during graph creation.
|
||||||
if input.get_shape().is_fully_defined():
|
input_spatial_shape = None
|
||||||
input_shape = np.array(input.get_shape().as_list())
|
if input.get_shape().ndims is not None:
|
||||||
else:
|
input_shape_list = input.get_shape().as_list()
|
||||||
input_shape = array_ops.shape(input)
|
input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
|
||||||
|
if input_spatial_shape is None or None in input_spatial_shape:
|
||||||
|
input_spatial_shape = array_ops.gather(array_ops.shape(input), spatial_dims)
|
||||||
|
|
||||||
input_spatial_shape = input_shape[1:num_spatial_dims+1]
|
|
||||||
paddings, crops = array_ops.required_space_to_batch_paddings(
|
paddings, crops = array_ops.required_space_to_batch_paddings(
|
||||||
input_shape=input_spatial_shape,
|
input_shape=input_spatial_shape,
|
||||||
base_paddings=base_paddings,
|
base_paddings=base_paddings,
|
||||||
block_shape=dilation_rate)
|
block_shape=dilation_rate)
|
||||||
|
|
||||||
input_converted = array_ops.space_to_batch_nd(input=input,
|
def adjust(orig, fill_value):
|
||||||
block_shape=dilation_rate,
|
"""Returns an `adjusted` version of `orig` based on `spatial_dims`.
|
||||||
paddings=paddings)
|
|
||||||
|
Tensor of the same type as `orig` and with shape
|
||||||
|
`[max(spatial_dims), ...]` where:
|
||||||
|
|
||||||
|
adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]
|
||||||
|
|
||||||
|
for 0 <= i < len(spatial_dims), and
|
||||||
|
|
||||||
|
adjusted[j, ...] = fill_value
|
||||||
|
|
||||||
|
for j != spatial_dims[i] - 1 for some i.
|
||||||
|
|
||||||
|
If `orig` is a constant value, then the result will be a constant value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orig: Tensor of rank > max(spatial_dims).
|
||||||
|
fill_value: Numpy scalar (of same data type as `orig) specifying the fill
|
||||||
|
value for non-spatial dimensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`adjusted` tensor.
|
||||||
|
"""
|
||||||
|
fill_dims = orig.get_shape().as_list()[1:]
|
||||||
|
dtype = orig.dtype.as_numpy_dtype
|
||||||
|
parts = []
|
||||||
|
const_orig = tensor_util.constant_value(orig)
|
||||||
|
const_or_orig = const_orig if const_orig is not None else orig
|
||||||
|
prev_spatial_dim = 0
|
||||||
|
i = 0
|
||||||
|
while i < len(spatial_dims):
|
||||||
|
start_i = i
|
||||||
|
start_spatial_dim = spatial_dims[i]
|
||||||
|
if start_spatial_dim > 1:
|
||||||
|
# Fill in any gap from the previous spatial dimension (or dimension 1 if
|
||||||
|
# this is the first spatial dimension) with `fill_value`.
|
||||||
|
parts.append(
|
||||||
|
np.full(
|
||||||
|
[start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
|
||||||
|
fill_value,
|
||||||
|
dtype=dtype))
|
||||||
|
# Find the largest value of i such that:
|
||||||
|
# [spatial_dims[start_i], ..., spatial_dims[i]]
|
||||||
|
# == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
|
||||||
|
# i.e. the end of a contiguous group of spatial dimensions.
|
||||||
|
while (i + 1 < len(spatial_dims) and
|
||||||
|
spatial_dims[i + 1] == spatial_dims[i] + 1):
|
||||||
|
i += 1
|
||||||
|
parts.append(const_or_orig[start_i:i + 1])
|
||||||
|
prev_spatial_dim = spatial_dims[i]
|
||||||
|
i += 1
|
||||||
|
if const_orig is not None:
|
||||||
|
return np.concatenate(parts)
|
||||||
|
else:
|
||||||
|
return array_ops.concat(0, parts)
|
||||||
|
|
||||||
|
dilation_rate = adjust(dilation_rate, 1)
|
||||||
|
paddings = adjust(paddings, 0)
|
||||||
|
crops = adjust(crops, 0)
|
||||||
|
|
||||||
|
input_converted = array_ops.space_to_batch_nd(
|
||||||
|
input=input,
|
||||||
|
block_shape=dilation_rate,
|
||||||
|
paddings=paddings)
|
||||||
|
|
||||||
result = op(input_converted, num_spatial_dims, "VALID")
|
result = op(input_converted, num_spatial_dims, "VALID")
|
||||||
|
|
||||||
result_converted = array_ops.batch_to_space_nd(input=result,
|
result_converted = array_ops.batch_to_space_nd(
|
||||||
block_shape=dilation_rate,
|
input=result, block_shape=dilation_rate, crops=crops)
|
||||||
crops=crops)
|
|
||||||
return result_converted
|
return result_converted
|
||||||
|
|
||||||
|
|
||||||
@ -333,7 +471,8 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
|
|||||||
|
|
||||||
|
|
||||||
def convolution(input, filter, # pylint: disable=redefined-builtin
|
def convolution(input, filter, # pylint: disable=redefined-builtin
|
||||||
padding, strides=None, dilation_rate=None, name=None):
|
padding, strides=None, dilation_rate=None,
|
||||||
|
name=None, data_format=None):
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
"""Computes sums of N-D convolutions (actually cross-correlation).
|
"""Computes sums of N-D convolutions (actually cross-correlation).
|
||||||
|
|
||||||
@ -343,7 +482,8 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
the optional `dilation_rate` parameter. Currently, however, output striding
|
the optional `dilation_rate` parameter. Currently, however, output striding
|
||||||
is not supported for atrous convolutions.
|
is not supported for atrous convolutions.
|
||||||
|
|
||||||
Specifically, given rank (N+2) `input` Tensor of shape
|
Specifically, in the case that `data_format` does not start with "NC", given
|
||||||
|
a rank (N+2) `input` Tensor of shape
|
||||||
|
|
||||||
[num_batches,
|
[num_batches,
|
||||||
input_spatial_shape[0],
|
input_spatial_shape[0],
|
||||||
@ -368,23 +508,34 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
sum_{z[0], ..., z[N-1], q}
|
sum_{z[0], ..., z[N-1], q}
|
||||||
|
|
||||||
filters[z[0], ..., z[N-1], q, k] *
|
filter[z[0], ..., z[N-1], q, k] *
|
||||||
padded_input[b,
|
padded_input[b,
|
||||||
x[0]*strides[0] + dilation_rate[0]*z[0],
|
x[0]*strides[0] + dilation_rate[0]*z[0],
|
||||||
...,
|
...,
|
||||||
x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
|
x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
|
||||||
q],
|
q]
|
||||||
|
|
||||||
where `padded_input` is obtained by zero padding the input using an effective
|
where `padded_input` is obtained by zero padding the input using an effective
|
||||||
spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
|
spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
|
||||||
output striding `strides` as described in the
|
output striding `strides` as described in the
|
||||||
[comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
|
[comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
|
||||||
|
|
||||||
|
In the case that `data_format` does start with `"NC"`, the `input` and output
|
||||||
|
(but not the `filter`) are simply transposed as follows:
|
||||||
|
|
||||||
|
convolution(input, data_format, **kwargs) =
|
||||||
|
tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
|
||||||
|
**kwargs),
|
||||||
|
[0, N+1] + range(1, N+1))
|
||||||
|
|
||||||
It is required that 1 <= N <= 3.
|
It is required that 1 <= N <= 3.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input: An N-D `Tensor` of type `T`, of shape
|
input: An N-D `Tensor` of type `T`, of shape
|
||||||
`[batch_size] + input_spatial_shape + [in_channels]`.
|
`[batch_size] + input_spatial_shape + [in_channels]` if data_format does
|
||||||
|
not start with "NC" (default), or
|
||||||
|
`[batch_size, in_channels] + input_spatial_shape` if data_format starts
|
||||||
|
with "NC".
|
||||||
filter: An N-D `Tensor` with the same type as `input` and shape
|
filter: An N-D `Tensor` with the same type as `input` and shape
|
||||||
`spatial_filter_shape + [in_channels, out_channels]`.
|
`spatial_filter_shape + [in_channels, out_channels]`.
|
||||||
padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
|
padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
|
||||||
@ -400,12 +551,23 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
filter in each spatial dimension i. If any value of dilation_rate is > 1,
|
filter in each spatial dimension i. If any value of dilation_rate is > 1,
|
||||||
then all values of strides must be 1.
|
then all values of strides must be 1.
|
||||||
name: Optional name for the returned tensor.
|
name: Optional name for the returned tensor.
|
||||||
|
data_format: A string or None. Specifies whether the channel dimension of
|
||||||
|
the `input` and output is the last dimension (default, or if `data_format`
|
||||||
|
does not start with "NC"), or the second dimension (if `data_format`
|
||||||
|
starts with "NC"). For N=1, the valid values are "NWC" (default) and
|
||||||
|
"NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
|
||||||
|
N=3, the valid value is "NDHWC".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Tensor` with the same type as `value` of shape
|
A `Tensor` with the same type as `input` of shape
|
||||||
|
|
||||||
`[batch_size] + output_spatial_shape + [out_channels]`,
|
`[batch_size] + output_spatial_shape + [out_channels]`
|
||||||
|
|
||||||
|
if data_format is None or does not start with "NC", or
|
||||||
|
|
||||||
|
`[batch_size, out_channels] + output_spatial_shape`
|
||||||
|
|
||||||
|
if data_format starts with "NC",
|
||||||
where `output_spatial_shape` depends on the value of `padding`.
|
where `output_spatial_shape` depends on the value of `padding`.
|
||||||
|
|
||||||
If padding == "SAME":
|
If padding == "SAME":
|
||||||
@ -418,8 +580,8 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
/ strides[i]).
|
/ strides[i]).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If input/output depth does not match `filter` shape, or if
|
ValueError: If input/output depth does not match `filter` shape, if padding
|
||||||
padding is other than `"VALID"` or `"SAME"`.
|
is other than `"VALID"` or `"SAME"`, or if data_format is invalid.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
@ -444,12 +606,19 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
|
ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
|
||||||
|
|
||||||
if not input.get_shape()[num_spatial_dims + 1].is_compatible_with(
|
if data_format is None or not data_format.startswith("NC"):
|
||||||
filter.get_shape()[num_spatial_dims]):
|
input_channels_dim = input.get_shape()[num_spatial_dims + 1]
|
||||||
|
spatial_dims = range(1, num_spatial_dims+1)
|
||||||
|
else:
|
||||||
|
input_channels_dim = input.get_shape()[1]
|
||||||
|
spatial_dims = range(2, num_spatial_dims+2)
|
||||||
|
|
||||||
|
if not input_channels_dim.is_compatible_with(filter.get_shape()[
|
||||||
|
num_spatial_dims]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"number of input channels does not match corresponding dimension of filter, "
|
"number of input channels does not match corresponding dimension of filter, "
|
||||||
"{} != {}".format(input.get_shape()[num_spatial_dims + 1],
|
"{} != {}".format(input_channels_dim, filter.get_shape()[
|
||||||
filter.get_shape()[num_spatial_dims]))
|
num_spatial_dims]))
|
||||||
|
|
||||||
strides, dilation_rate = _get_strides_and_dilation_rate(
|
strides, dilation_rate = _get_strides_and_dilation_rate(
|
||||||
num_spatial_dims, strides, dilation_rate)
|
num_spatial_dims, strides, dilation_rate)
|
||||||
@ -459,12 +628,14 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
|
|||||||
input=input_converted,
|
input=input_converted,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
|
data_format=data_format,
|
||||||
strides=strides,
|
strides=strides,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
return with_space_to_batch(
|
return with_space_to_batch(
|
||||||
input=input,
|
input=input,
|
||||||
filter_shape=array_ops.shape(filter),
|
filter_shape=array_ops.shape(filter),
|
||||||
|
spatial_dims=spatial_dims,
|
||||||
dilation_rate=dilation_rate,
|
dilation_rate=dilation_rate,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
op=op)
|
op=op)
|
||||||
@ -476,11 +647,12 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
padding,
|
padding,
|
||||||
dilation_rate=None,
|
dilation_rate=None,
|
||||||
strides=None,
|
strides=None,
|
||||||
name=None):
|
name=None,
|
||||||
|
data_format=None):
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
"""Performs an N-D pooling operation.
|
"""Performs an N-D pooling operation.
|
||||||
|
|
||||||
Computes for
|
In the case that `data_format` does not start with "NC", computes for
|
||||||
0 <= b < batch_size,
|
0 <= b < batch_size,
|
||||||
0 <= x[i] < output_spatial_shape[i],
|
0 <= x[i] < output_spatial_shape[i],
|
||||||
0 <= c < num_channels:
|
0 <= c < num_channels:
|
||||||
@ -498,10 +670,20 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
[comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
|
[comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
|
||||||
The reduction never includes out-of-bounds positions.
|
The reduction never includes out-of-bounds positions.
|
||||||
|
|
||||||
|
In the case that `data_format` starts with `"NC"`, the `input` and output are
|
||||||
|
simply transposed as follows:
|
||||||
|
|
||||||
|
pool(input, data_format, **kwargs) =
|
||||||
|
tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
|
||||||
|
**kwargs),
|
||||||
|
[0, N+1] + range(1, N+1))
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input: Tensor of rank N+2, of shape
|
input: Tensor of rank N+2, of shape
|
||||||
[batch_size] + input_spatial_shape + [num_channels].
|
`[batch_size] + input_spatial_shape + [num_channels]` if data_format does
|
||||||
Pooling happens over the spatial dimensions only.
|
not start with "NC" (default), or
|
||||||
|
`[batch_size, num_channels] + input_spatial_shape` if data_format starts
|
||||||
|
with "NC". Pooling happens over the spatial dimensions only.
|
||||||
window_shape: Sequence of N ints >= 1.
|
window_shape: Sequence of N ints >= 1.
|
||||||
pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
|
pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
|
||||||
padding: The padding algorithm, must be "SAME" or "VALID".
|
padding: The padding algorithm, must be "SAME" or "VALID".
|
||||||
@ -513,10 +695,22 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
If any value of strides is > 1, then all values of dilation_rate must be
|
If any value of strides is > 1, then all values of dilation_rate must be
|
||||||
1.
|
1.
|
||||||
name: Optional. Name of the op.
|
name: Optional. Name of the op.
|
||||||
|
data_format: A string or None. Specifies whether the channel dimension of
|
||||||
|
the `input` and output is the last dimension (default, or if `data_format`
|
||||||
|
does not start with "NC"), or the second dimension (if `data_format`
|
||||||
|
starts with "NC"). For N=1, the valid values are "NWC" (default) and
|
||||||
|
"NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
|
||||||
|
N=3, the valid value is "NDHWC".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor of rank N+2, of shape
|
Tensor of rank N+2, of shape
|
||||||
[batch_size] + output_spatial_shape + [num_channels],
|
[batch_size] + output_spatial_shape + [num_channels]
|
||||||
|
|
||||||
|
if data_format is None or does not start with "NC", or
|
||||||
|
|
||||||
|
[batch_size, num_channels] + output_spatial_shape
|
||||||
|
|
||||||
|
if data_format starts with "NC",
|
||||||
where `output_spatial_shape` depends on the value of padding:
|
where `output_spatial_shape` depends on the value of padding:
|
||||||
|
|
||||||
If padding = "SAME":
|
If padding = "SAME":
|
||||||
@ -536,8 +730,8 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
input = ops.convert_to_tensor(input, name="input")
|
input = ops.convert_to_tensor(input, name="input")
|
||||||
|
|
||||||
num_spatial_dims = len(window_shape)
|
num_spatial_dims = len(window_shape)
|
||||||
if num_spatial_dims < 2 or num_spatial_dims > 3:
|
if num_spatial_dims < 1 or num_spatial_dims > 3:
|
||||||
raise ValueError("It is required that 2 <= num_spatial_dims <= 3.")
|
raise ValueError("It is required that 1 <= num_spatial_dims <= 3.")
|
||||||
|
|
||||||
input.get_shape().with_rank(num_spatial_dims + 2)
|
input.get_shape().with_rank(num_spatial_dims + 2)
|
||||||
|
|
||||||
@ -553,8 +747,10 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
"strides > window_shape not supported due to inconsistency between "
|
"strides > window_shape not supported due to inconsistency between "
|
||||||
"CPU and GPU implementations")
|
"CPU and GPU implementations")
|
||||||
|
|
||||||
pooling_ops = {("MAX", 2): max_pool,
|
pooling_ops = {("MAX", 1): max_pool,
|
||||||
|
("MAX", 2): max_pool,
|
||||||
("MAX", 3): max_pool3d, # pylint: disable=undefined-variable
|
("MAX", 3): max_pool3d, # pylint: disable=undefined-variable
|
||||||
|
("AVG", 1): avg_pool,
|
||||||
("AVG", 2): avg_pool,
|
("AVG", 2): avg_pool,
|
||||||
("AVG", 3): avg_pool3d, # pylint: disable=undefined-variable
|
("AVG", 3): avg_pool3d, # pylint: disable=undefined-variable
|
||||||
}
|
}
|
||||||
@ -563,18 +759,52 @@ def pool(input, # pylint: disable=redefined-builtin
|
|||||||
raise ValueError("%d-D %s pooling is not supported." %
|
raise ValueError("%d-D %s pooling is not supported." %
|
||||||
(op_key[1], op_key[0]))
|
(op_key[1], op_key[0]))
|
||||||
|
|
||||||
def op(converted_input, _, converted_padding):
|
if data_format is None or not data_format.startswith("NC"):
|
||||||
return pooling_ops[op_key](converted_input,
|
adjusted_window_shape = [1] + list(window_shape) + [1]
|
||||||
[1] + list(window_shape) + [1],
|
adjusted_strides = [1] + list(strides) + [1]
|
||||||
[1] + list(strides) + [1],
|
spatial_dims = range(1, num_spatial_dims + 1)
|
||||||
converted_padding,
|
else:
|
||||||
name=scope)
|
adjusted_window_shape = [1, 1] + list(window_shape)
|
||||||
|
adjusted_strides = [1, 1] + list(strides)
|
||||||
|
spatial_dims = range(2, num_spatial_dims + 2)
|
||||||
|
|
||||||
return with_space_to_batch(input=input,
|
if num_spatial_dims == 3:
|
||||||
dilation_rate=dilation_rate,
|
if data_format is not None and data_format != "NDHWC":
|
||||||
padding=padding,
|
raise ValueError("data_format must be \"NDHWC\".")
|
||||||
op=op,
|
data_format_kwargs = dict()
|
||||||
filter_shape=window_shape)
|
elif num_spatial_dims == 1:
|
||||||
|
if data_format is None or data_format == "NWC":
|
||||||
|
data_format_kwargs = dict(data_format="NHWC")
|
||||||
|
elif data_format == "NCW":
|
||||||
|
data_format_kwargs = dict(data_format="NCHW")
|
||||||
|
else:
|
||||||
|
raise ValueError("data_format must be either \"NWC\" or \"NCW\".")
|
||||||
|
adjusted_window_shape = [1] + adjusted_window_shape
|
||||||
|
adjusted_strides = [1] + adjusted_strides
|
||||||
|
else:
|
||||||
|
data_format_kwargs = dict(data_format=data_format)
|
||||||
|
|
||||||
|
def op(converted_input, _, converted_padding): # pylint: disable=missing-docstring
|
||||||
|
if num_spatial_dims == 1:
|
||||||
|
converted_input = array_ops.expand_dims(converted_input,
|
||||||
|
spatial_dims[0])
|
||||||
|
result = pooling_ops[op_key](converted_input,
|
||||||
|
adjusted_window_shape,
|
||||||
|
adjusted_strides,
|
||||||
|
converted_padding,
|
||||||
|
name=scope,
|
||||||
|
**data_format_kwargs)
|
||||||
|
if num_spatial_dims == 1:
|
||||||
|
result = array_ops.squeeze(result, [spatial_dims[0]])
|
||||||
|
return result
|
||||||
|
|
||||||
|
return with_space_to_batch(
|
||||||
|
input=input,
|
||||||
|
dilation_rate=dilation_rate,
|
||||||
|
padding=padding,
|
||||||
|
op=op,
|
||||||
|
spatial_dims=spatial_dims,
|
||||||
|
filter_shape=window_shape)
|
||||||
|
|
||||||
|
|
||||||
def atrous_conv2d(value, filters, rate, padding, name=None):
|
def atrous_conv2d(value, filters, rate, padding, name=None):
|
||||||
@ -1794,19 +2024,27 @@ def conv1d(value, filters, stride, padding,
|
|||||||
name=None):
|
name=None):
|
||||||
"""Computes a 1-D convolution given 3-D input and filter tensors.
|
"""Computes a 1-D convolution given 3-D input and filter tensors.
|
||||||
|
|
||||||
Given an input tensor of shape [batch, in_width, in_channels]
|
Given an input tensor of shape
|
||||||
|
[batch, in_width, in_channels]
|
||||||
|
if data_format is "NHWC", or
|
||||||
|
[batch, in_channels, in_width]
|
||||||
|
if data_format is "NCHW",
|
||||||
and a filter / kernel tensor of shape
|
and a filter / kernel tensor of shape
|
||||||
[filter_width, in_channels, out_channels], this op reshapes
|
[filter_width, in_channels, out_channels], this op reshapes
|
||||||
the arguments to pass them to conv2d to perform the equivalent
|
the arguments to pass them to conv2d to perform the equivalent
|
||||||
convolution operation.
|
convolution operation.
|
||||||
|
|
||||||
Internally, this op reshapes the input tensors and invokes
|
Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
|
||||||
`tf.nn.conv2d`. A tensor of shape [batch, in_width, in_channels]
|
For example, if `data_format` does not start with "NC", a tensor of shape
|
||||||
is reshaped to [batch, 1, in_width, in_channels], and the filter
|
[batch, in_width, in_channels]
|
||||||
is reshaped to [1, filter_width, in_channels, out_channels].
|
is reshaped to
|
||||||
The result is then reshaped back to [batch, out_width, out_channels]
|
[batch, 1, in_width, in_channels],
|
||||||
(where out_width is a function of the stride and padding as in
|
and the filter is reshaped to
|
||||||
conv2d) and returned to the caller.
|
[1, filter_width, in_channels, out_channels].
|
||||||
|
The result is then reshaped back to
|
||||||
|
[batch, out_width, out_channels]
|
||||||
|
(where out_width is a function of the stride and padding as in conv2d) and
|
||||||
|
returned to the caller.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: A 3D `Tensor`. Must be of type `float32` or `float64`.
|
value: A 3D `Tensor`. Must be of type `float32` or `float64`.
|
||||||
@ -1823,16 +2061,27 @@ def conv1d(value, filters, stride, padding,
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Tensor`. Has the same type as input.
|
A `Tensor`. Has the same type as input.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `data_format` is invalid.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "conv1d", [value, filters]) as name:
|
with ops.name_scope(name, "conv1d", [value, filters]) as name:
|
||||||
# Reshape the input tensor to [batch, 1, in_width, in_channels]
|
# Reshape the input tensor to [batch, 1, in_width, in_channels]
|
||||||
value = array_ops.expand_dims(value, 1)
|
if data_format is None or data_format == "NHWC":
|
||||||
# And reshape the filter to [1, filter_width, in_channels, out_channels]
|
data_format = "NHWC"
|
||||||
|
spatial_start_dim = 1
|
||||||
|
strides = [1, 1, stride, 1]
|
||||||
|
elif data_format == "NCHW":
|
||||||
|
spatial_start_dim = 2
|
||||||
|
strides = [1, 1, 1, stride]
|
||||||
|
else:
|
||||||
|
raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
|
||||||
|
value = array_ops.expand_dims(value, spatial_start_dim)
|
||||||
filters = array_ops.expand_dims(filters, 0)
|
filters = array_ops.expand_dims(filters, 0)
|
||||||
result = gen_nn_ops.conv2d(value, filters, [1, 1, stride, 1], padding,
|
result = gen_nn_ops.conv2d(value, filters, strides, padding,
|
||||||
use_cudnn_on_gpu=use_cudnn_on_gpu,
|
use_cudnn_on_gpu=use_cudnn_on_gpu,
|
||||||
data_format=data_format)
|
data_format=data_format)
|
||||||
return array_ops.squeeze(result, [1])
|
return array_ops.squeeze(result, [spatial_start_dim])
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
|
||||||
|
Loading…
Reference in New Issue
Block a user