Add data_format support for N-D convolution and pooling.
This adds support for "NC*" data layouts for N-D convolution and pooling (including atrous convolution and pooling); previously, only "N*C" data layouts were supported. This also adds support for 1-D pooling (by forwarding to the 2-D implementation), and fixes the broken data_format support in conv1d.

Change: 136556507
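For orientation, a minimal usage sketch of the new interface (not part of the commit itself; per the tests below, the "NC*" layouts currently require a GPU):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 3, 9), dtype=tf.float32)  # "NCW": [batch, channels, width]
f = tf.constant(np.random.rand(3, 3, 4), dtype=tf.float32)  # [filter_width, in_channels, out_channels]

# Atrous 1-D convolution in the channels-second layout.
y = tf.nn.convolution(input=x, filter=f, padding="SAME",
                      dilation_rate=[2], data_format="NCW")

# 1-D pooling (new in this change), same layout.
p = tf.nn.pool(input=x, window_shape=[2], pooling_type="MAX",
               padding="VALID", strides=[1], data_format="NCW")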
Parent: 45010d6a49
Commit: 8fa9b949dc
@@ -48,15 +48,15 @@ def upsample_filters(filters, rate):
 class AtrousConvolutionTest(tf.test.TestCase):
 
-  def _test_atrous_convolution(self, input_shape, filter_shape, padding,
-                               dilation_rate):
+  def _test_atrous_convolution(self, input_shape, filter_shape, dilation_rate,
+                               **kwargs):
     filters = np.arange(
         np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
     filters_upsampled = upsample_filters(filters, dilation_rate)
     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
     y1 = tf.nn.convolution(
-        input=x, filter=filters, padding=padding, dilation_rate=dilation_rate)
-    y2 = tf.nn.convolution(input=x, filter=filters_upsampled, padding=padding)
+        input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
+    y2 = tf.nn.convolution(input=x, filter=filters_upsampled, **kwargs)
     self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)
 
   def testAtrousConvolution2D(self):
@@ -99,6 +99,24 @@ class AtrousConvolutionTest(tf.test.TestCase):
             padding=padding,
             dilation_rate=[rate])
 
+  def testAtrousConvolutionNC(self):
+    if tf.test.is_gpu_available():
+      # "NCW" and "NCHW" formats are not currently supported on CPU.
+      with self.test_session(use_gpu=True):
+        for padding in ["SAME", "VALID"]:
+          self._test_atrous_convolution(
+              input_shape=[2, 2, 9],
+              padding=padding,
+              filter_shape=[3, 2, 2],
+              dilation_rate=[2],
+              data_format="NCW")
+          self._test_atrous_convolution(
+              input_shape=[2, 2, 9, 5],
+              padding=padding,
+              filter_shape=[3, 3, 2, 2],
+              dilation_rate=[2, 1],
+              data_format="NCHW")
+
   def testAtrousSequence(self):
     """Tests optimization of sequence of atrous convolutions.
@@ -89,7 +89,7 @@ def pool_direct_single_axis(input,  # pylint: disable=redefined-builtin
 
 
 def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
-                dilation_rate, strides):
+                dilation_rate, strides, data_format=None):
   """Numpy implementation of pooling.
 
   This is intended for testing only, and therefore isn't particularly efficient.
@@ -103,6 +103,8 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
     padding: either "SAME" or "VALID".
     dilation_rate: Sequence of N ints >= 1.
     strides: Sequence of N ints >= 1.
+    data_format: If specified and starts with "NC", indicates that second
+      dimension, rather than the last dimension, specifies the channel.
 
   Returns:
     pooling output array of rank N+2.
@@ -110,11 +112,15 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
   Raises:
     ValueError: if arguments are invalid.
   """
+  if data_format is None or not data_format.startswith("NC"):
+    spatial_start_dim = 1
+  else:
+    spatial_start_dim = 2
   output = input
   for i in range(len(window_shape)):
     output = pool_direct_single_axis(
         input=output,
-        axis=i + 1,
+        axis=i + spatial_start_dim,
         window_size=window_shape[i],
         pooling_type=pooling_type,
         padding=padding,
@@ -125,26 +131,13 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
 
 class PoolingTest(tf.test.TestCase):
 
-  def _test(self, input_shape, window_shape, pooling_type, padding,
-            dilation_rate, strides):
+  def _test(self, input_shape, **kwargs):
     # Use negative numbers to make sure there isn't any zero padding getting
     # used.
     x = -np.arange(
         np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
-    y1 = pool_direct(
-        input=x,
-        window_shape=window_shape,
-        pooling_type=pooling_type,
-        padding=padding,
-        dilation_rate=dilation_rate,
-        strides=strides)
-    y2 = tf.nn.pool(
-        input=x,
-        window_shape=window_shape,
-        pooling_type=pooling_type,
-        padding=padding,
-        dilation_rate=dilation_rate,
-        strides=strides)
+    y1 = pool_direct(input=x, **kwargs)
+    y2 = tf.nn.pool(input=x, **kwargs)
     self.assertAllClose(y1, y2.eval(), rtol=1e-2, atol=1e-2)
 
   def testPoolSimple(self):
@@ -159,6 +152,32 @@ class PoolingTest(tf.test.TestCase):
           dilation_rate=[1, 1],
           strides=[1, 2])
 
+  def testPool1D(self):
+    with self.test_session():
+      for padding in ["SAME", "VALID"]:
+        for pooling_type in ["MAX", "AVG"]:
+          for input_shape in [[2, 9, 2], [2, 10, 2]]:
+            for window_shape in [[1], [2], [3]]:
+              if padding != "SAME":
+                for dilation_rate in [[1], [2], [3]]:
+                  self._test(
+                      input_shape=input_shape,
+                      window_shape=window_shape,
+                      padding=padding,
+                      pooling_type=pooling_type,
+                      dilation_rate=dilation_rate,
+                      strides=[1])
+              for strides in [[1], [2], [3]]:
+                if np.any(np.array(strides) > window_shape):
+                  continue
+                self._test(
+                    input_shape=input_shape,
+                    window_shape=window_shape,
+                    padding=padding,
+                    pooling_type=pooling_type,
+                    dilation_rate=[1],
+                    strides=strides)
+
   def testPool2D(self):
     with self.test_session():
       for padding in ["SAME", "VALID"]:
@@ -212,6 +231,40 @@ class PoolingTest(tf.test.TestCase):
           dilation_rate=[1, 1, 1],
           strides=strides)
 
+  def testPoolNC(self):
+    if tf.test.is_gpu_available():
+      # "NC*" format is not currently supported on CPU.
+      with self.test_session(use_gpu=True):
+        for padding in ["SAME", "VALID"]:
+          self._test(input_shape=[2, 2, 9],
+                     window_shape=[2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[1],
+                     dilation_rate=[1],
+                     data_format="NCW")
+          self._test(input_shape=[2, 2, 9],
+                     window_shape=[2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[2],
+                     dilation_rate=[1],
+                     data_format="NCW")
+          self._test(input_shape=[2, 2, 7, 9],
+                     window_shape=[2, 2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[1, 2],
+                     dilation_rate=[1, 1],
+                     data_format="NCHW")
+          self._test(input_shape=[2, 2, 7, 9],
+                     window_shape=[2, 2],
+                     padding="VALID",
+                     pooling_type="MAX",
+                     strides=[1, 1],
+                     dilation_rate=[2, 2],
+                     data_format="NCHW")
+
   def _test_gradient(self, input_shape, **kwargs):
     x_val = -np.arange(
         np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
@@ -224,6 +277,32 @@ class PoolingTest(tf.test.TestCase):
     err_tolerance = 1e-2
     self.assertLess(err, err_tolerance)
 
+  def testGradient1D(self):
+    with self.test_session():
+      for padding in ["SAME", "VALID"]:
+        for pooling_type in ["AVG", "MAX"]:
+          for input_shape in [[2, 5, 2], [1, 4, 1]]:
+            for window_shape in [[1], [2]]:
+              if padding != "SAME":
+                for dilation_rate in [[1], [2]]:
+                  self._test_gradient(
+                      input_shape=input_shape,
+                      window_shape=window_shape,
+                      padding=padding,
+                      pooling_type=pooling_type,
+                      dilation_rate=dilation_rate,
+                      strides=[1])
+              for strides in [[1], [2]]:
+                if np.any(np.array(strides) > window_shape):
+                  continue
+                self._test(
+                    input_shape=input_shape,
+                    window_shape=window_shape,
+                    padding=padding,
+                    pooling_type=pooling_type,
+                    dilation_rate=[1],
+                    strides=strides)
+
   def testGradient2D(self):
     with self.test_session():
       for padding in ["SAME", "VALID"]:
@@ -42,7 +42,8 @@ from tensorflow.python.ops.gen_nn_ops import *
 local_response_normalization = gen_nn_ops.lrn
 
 
-def _non_atrous_convolution(input, filter, padding, strides=None, name=None):  # pylint: disable=redefined-builtin
+def _non_atrous_convolution(input, filter, padding, data_format=None,  # pylint: disable=redefined-builtin
+                            strides=None, name=None):
   """Computes sums of N-D convolutions (actually cross correlation).
 
   It is required that 1 <= N <= 3.
@@ -51,12 +52,22 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None):
   extends the interface of this function with a `dilation_rate` parameter.
 
   Args:
     input: Rank N+2 tensor of type T of shape
-      `[batch_size] + input_spatial_shape + [in_channels]`.
+      `[batch_size] + input_spatial_shape + [in_channels]` if `data_format`
+      does not start with `"NC"`, or
+      `[batch_size, in_channels] + input_spatial_shape` if `data_format` starts
+      with `"NC"`.
     filter: Rank N+2 tensor of type T of shape
       `filter_spatial_shape + [in_channels, out_channels]`.  Rank of either
       `input` or `filter` must be known.
     padding: Padding method to use, must be either "VALID" or "SAME".
+    data_format: A string or None.  Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC").  For N=1, the valid values are "NWC" (default) and
+      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".  For
+      N=3, the valid value is "NDHWC".
     strides: Sequence of N positive integers, defaults to `[1] * N`.
     name: Name prefix to use.
 
@@ -89,26 +100,50 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None):
     raise ValueError("len(strides)=%d, but should be %d" %
                      (len(strides), conv_dims))
   if conv_dims == 1:
-    return conv1d(value=input,
-                  filters=filter,
-                  stride=strides[0],
-                  padding=padding,
-                  name=scope)
+    # conv1d uses the 2-d data format names
+    if data_format is None or data_format == "NWC":
+      data_format_2d = "NHWC"
+    elif data_format == "NCW":
+      data_format_2d = "NCHW"
+    else:
+      raise ValueError("data_format must be \"NWC\" or \"NCW\".")
+    return conv1d(
+        value=input,
+        filters=filter,
+        stride=strides[0],
+        padding=padding,
+        data_format=data_format_2d,
+        name=scope)
   elif conv_dims == 2:
-    return gen_nn_ops.conv2d(input=input,
-                             filter=filter,
-                             strides=[1] + list(strides) + [1],
-                             padding=padding,
-                             name=name)
+    if data_format is None or data_format == "NHWC":
+      data_format = "NHWC"
+      strides = [1] + list(strides) + [1]
+    elif data_format == "NCHW":
+      strides = [1, 1] + list(strides)
+    else:
+      raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
+    return gen_nn_ops.conv2d(
+        input=input,
+        filter=filter,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        name=name)
   elif conv_dims == 3:
-    return gen_nn_ops.conv3d(input=input,
-                             filter=filter,
-                             strides=[1] + list(strides) + [1],
-                             padding=padding,
-                             name=name)
+    if data_format is None or data_format == "NDHWC":
+      strides = [1] + list(strides) + [1]
+    else:
+      raise ValueError("data_format must be \"NDHWC\".")
+    return gen_nn_ops.conv3d(
+        input=input,
+        filter=filter,
+        strides=strides,
+        padding=padding,
+        name=name)
 
 
-def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):  # pylint: disable=redefined-builtin
+def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None,  # pylint: disable=redefined-builtin
+                        spatial_dims=None):
   """Performs `op` on the space-to-batch representation of `input`.
 
   This has the effect of transforming sliding window operations into the
@@ -122,19 +157,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
   Otherwise, it returns:
 
     batch_to_space_nd(
-      op(space_to_batch_nd(input, dilation_rate, paddings),
+      op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
          num_spatial_dims,
          "VALID")
-      dilation_rate,
-      crops),
+      adjusted_dilation_rate,
+      adjusted_crops),
 
-  where `paddings` and `crops` are int32 [num_spatial_dims, 2] tensors that
-  depend on the value of `padding`:
+  where:
+
+    adjusted_dilation_rate is an int64 tensor of shape [max(spatial_dims)],
+    adjusted_{paddings,crops} are int64 tensors of shape [max(spatial_dims), 2]
+
+  defined as follows:
+
+  We first define two int64 tensors `paddings` and `crops` of shape
+  `[num_spatial_dims, 2]` based on the value of `padding` and the spatial
+  dimensions of the `input`:
 
   If `padding = "VALID"`, then:
 
     paddings, crops = required_space_to_batch_paddings(
-      input_shape[1:num_spatial_dims+1],
+      input_shape[spatial_dims],
      dilation_rate)
 
   If `padding = "SAME"`, then:
@@ -143,10 +186,30 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
       filter_shape + (filter_shape - 1) * (dilation_rate - 1)
 
     paddings, crops = required_space_to_batch_paddings(
-      input_shape[1:num_spatial_dims+1],
+      input_shape[spatial_dims],
       dilation_rate,
       [(dilated_filter_shape - 1) // 2,
        dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])
 
+  Because `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
+  dimensions are contiguous starting at the second dimension, but the specified
+  `spatial_dims` may not be, we must adjust `dilation_rate`, `paddings` and
+  `crops` in order to be usable with these operations.  For a given dimension,
+  if the block size is 1, and both the starting and ending padding and crop
+  amounts are 0, then space_to_batch_nd effectively leaves that dimension alone,
+  which is what is needed for dimensions not part of `spatial_dims`.
+  Furthermore, `space_to_batch_nd` and `batch_to_space_nd` handle this case
+  efficiently for any number of leading and trailing dimensions.
+
+  For 0 <= i < len(spatial_dims), we assign:
+
+    adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
+    adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
+    adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]
+
+  All unassigned values of `adjusted_dilation_rate` default to 1, while all
+  unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.
+
   Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
   padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
   `[1]*N`.
@@ -189,19 +252,23 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
     net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
 
   Args:
-    input: Tensor of rank >= 1 + num_spatial_dims.
+    input: Tensor of rank > max(spatial_dims).
     dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
     padding: str constant equal to "VALID" or "SAME"
     op: Function that maps (input, num_spatial_dims, padding) -> output
     filter_shape: If padding = "SAME", specifies the shape of the convolution
       kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
       If padding = "VALID", filter_shape is ignored and need not be specified.
+    spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
+      integers (which are >= 1) specifying the spatial dimensions of `input`
+      and output.  Defaults to: `range(1, num_spatial_dims+1)`.
 
   Returns:
     The output Tensor as described above.
 
   Raises:
-    ValueError: if padding is invalid or the arguments are incompatible.
+    ValueError: if `padding` is invalid or the arguments are incompatible.
+    ValueError: if `spatial_dims` are invalid.
 
   """
   input = ops.convert_to_tensor(input, name="input")
@@ -218,18 +285,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
 
   num_spatial_dims = rate_shape[0].value
 
+  if spatial_dims is None:
+    spatial_dims = range(1, num_spatial_dims + 1)
+  orig_spatial_dims = list(spatial_dims)
+  spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
+  if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
+    raise ValueError(
+        "spatial_dims must be a montonically increasing sequence of positive integers")  # pylint: disable=line-too-long
+  last_spatial_dim = spatial_dims[-1]
+
   try:
-    input.get_shape().with_rank_at_least(num_spatial_dims + 1)
+    input.get_shape().with_rank_at_least(last_spatial_dim + 1)
   except ValueError:
     ValueError("input tensor must have rank %d at least" %
-               (num_spatial_dims + 1))
+               (last_spatial_dim + 1))
 
   const_rate = tensor_util.constant_value(dilation_rate)
   rate_or_const_rate = dilation_rate
   if const_rate is not None:
     rate_or_const_rate = const_rate
     if np.any(const_rate < 1):
-      raise ValueError("rate must be positive")
+      raise ValueError("dilation_rate must be positive")
     if np.all(const_rate == 1):
       return op(input, num_spatial_dims, padding)
 
@@ -266,26 +342,88 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
     raise ValueError("Invalid padding method %r" % padding)
 
   # Handle input whose shape is unknown during graph creation.
-  if input.get_shape().is_fully_defined():
-    input_shape = np.array(input.get_shape().as_list())
-  else:
-    input_shape = array_ops.shape(input)
-
-  input_spatial_shape = input_shape[1:num_spatial_dims+1]
+  input_spatial_shape = None
+  if input.get_shape().ndims is not None:
+    input_shape_list = input.get_shape().as_list()
+    input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
+  if input_spatial_shape is None or None in input_spatial_shape:
+    input_spatial_shape = array_ops.gather(array_ops.shape(input), spatial_dims)
 
   paddings, crops = array_ops.required_space_to_batch_paddings(
       input_shape=input_spatial_shape,
       base_paddings=base_paddings,
       block_shape=dilation_rate)
 
-  input_converted = array_ops.space_to_batch_nd(input=input,
-                                                block_shape=dilation_rate,
-                                                paddings=paddings)
+  def adjust(orig, fill_value):
+    """Returns an `adjusted` version of `orig` based on `spatial_dims`.
+
+    Tensor of the same type as `orig` and with shape
+    `[max(spatial_dims), ...]` where:
+
+      adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]
+
+    for 0 <= i < len(spatial_dims), and
+
+      adjusted[j, ...] = fill_value
+
+    for j != spatial_dims[i] - 1 for some i.
+
+    If `orig` is a constant value, then the result will be a constant value.
+
+    Args:
+      orig: Tensor of rank > max(spatial_dims).
+      fill_value: Numpy scalar (of same data type as `orig) specifying the fill
+        value for non-spatial dimensions.
+
+    Returns:
+      `adjusted` tensor.
+    """
+    fill_dims = orig.get_shape().as_list()[1:]
+    dtype = orig.dtype.as_numpy_dtype
+    parts = []
+    const_orig = tensor_util.constant_value(orig)
+    const_or_orig = const_orig if const_orig is not None else orig
+    prev_spatial_dim = 0
+    i = 0
+    while i < len(spatial_dims):
+      start_i = i
+      start_spatial_dim = spatial_dims[i]
+      if start_spatial_dim > 1:
+        # Fill in any gap from the previous spatial dimension (or dimension 1 if
+        # this is the first spatial dimension) with `fill_value`.
+        parts.append(
+            np.full(
+                [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
+                fill_value,
+                dtype=dtype))
+      # Find the largest value of i such that:
+      #   [spatial_dims[start_i], ..., spatial_dims[i]]
+      #     == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
+      # i.e. the end of a contiguous group of spatial dimensions.
+      while (i + 1 < len(spatial_dims) and
+             spatial_dims[i + 1] == spatial_dims[i] + 1):
+        i += 1
+      parts.append(const_or_orig[start_i:i + 1])
+      prev_spatial_dim = spatial_dims[i]
+      i += 1
+    if const_orig is not None:
+      return np.concatenate(parts)
+    else:
+      return array_ops.concat(0, parts)
+
+  dilation_rate = adjust(dilation_rate, 1)
+  paddings = adjust(paddings, 0)
+  crops = adjust(crops, 0)
+
+  input_converted = array_ops.space_to_batch_nd(
+      input=input,
+      block_shape=dilation_rate,
+      paddings=paddings)
 
   result = op(input_converted, num_spatial_dims, "VALID")
 
-  result_converted = array_ops.batch_to_space_nd(input=result,
-                                                 block_shape=dilation_rate,
-                                                 crops=crops)
+  result_converted = array_ops.batch_to_space_nd(
+      input=result, block_shape=dilation_rate, crops=crops)
   return result_converted
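As an aside, the `adjusted_*` mapping documented above is easy to check with plain numpy for constant inputs; this illustrative helper (`adjust_constant` is hypothetical, not the commit's `adjust`) mirrors the assignment rule:

import numpy as np

def adjust_constant(orig, fill_value, spatial_dims):
  # adjusted[spatial_dims[i] - 1] = orig[i]; everything else = fill_value.
  orig = np.asarray(orig)
  adjusted = np.full((max(spatial_dims),) + orig.shape[1:], fill_value,
                     dtype=orig.dtype)
  for i, d in enumerate(spatial_dims):
    adjusted[d - 1] = orig[i]
  return adjusted

# dilation_rate [2, 3] on spatial_dims [2, 3] (e.g. an "NCHW" layout):
print(adjust_constant([2, 3], 1, [2, 3]))  # -> [1 2 3]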
@@ -333,7 +471,8 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
 
 
 def convolution(input, filter,  # pylint: disable=redefined-builtin
-                padding, strides=None, dilation_rate=None, name=None):
+                padding, strides=None, dilation_rate=None,
+                name=None, data_format=None):
   # pylint: disable=line-too-long
   """Computes sums of N-D convolutions (actually cross-correlation).
 
@@ -343,7 +482,8 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
   the optional `dilation_rate` parameter.  Currently, however, output striding
   is not supported for atrous convolutions.
 
-  Specifically, given rank (N+2) `input` Tensor of shape
+  Specifically, in the case that `data_format` does not start with "NC", given
+  a rank (N+2) `input` Tensor of shape
 
     [num_batches,
      input_spatial_shape[0],
@@ -368,23 +508,34 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
       sum_{z[0], ..., z[N-1], q}
 
-          filters[z[0], ..., z[N-1], q, k] *
+          filter[z[0], ..., z[N-1], q, k] *
           padded_input[b,
                        x[0]*strides[0] + dilation_rate[0]*z[0],
                        ...,
                        x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
-                       q],
+                       q]
 
   where `padded_input` is obtained by zero padding the input using an effective
   spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
   output striding `strides` as described in the
   [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
 
+  In the case that `data_format` does start with `"NC"`, the `input` and output
+  (but not the `filter`) are simply transposed as follows:
+
+    convolution(input, data_format, **kwargs) =
+      tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
+                               **kwargs),
+                   [0, N+1] + range(1, N+1))
+
   It is required that 1 <= N <= 3.
 
   Args:
     input: An N-D `Tensor` of type `T`, of shape
-      `[batch_size] + input_spatial_shape + [in_channels]`.
+      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
+      not start with "NC" (default), or
+      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
+      with "NC".
     filter: An N-D `Tensor` with the same type as `input` and shape
       `spatial_filter_shape + [in_channels, out_channels]`.
     padding: A string, either `"VALID"` or `"SAME"`.  The padding algorithm.
@@ -400,12 +551,23 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
       filter in each spatial dimension i.  If any value of dilation_rate is > 1,
       then all values of strides must be 1.
     name: Optional name for the returned tensor.
+    data_format: A string or None.  Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC").  For N=1, the valid values are "NWC" (default) and
+      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".  For
+      N=3, the valid value is "NDHWC".
 
   Returns:
-    A `Tensor` with the same type as `value` of shape
+    A `Tensor` with the same type as `input` of shape
 
-        `[batch_size] + output_spatial_shape + [out_channels]`,
+        `[batch_size] + output_spatial_shape + [out_channels]`
+
+    if data_format is None or does not start with "NC", or
+
+        `[batch_size, out_channels] + output_spatial_shape`
+
+    if data_format starts with "NC",
 
     where `output_spatial_shape` depends on the value of `padding`.
 
     If padding == "SAME":
@@ -418,8 +580,8 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
           / strides[i]).
 
   Raises:
-    ValueError: If input/output depth does not match `filter` shape, or if
-      padding is other than `"VALID"` or `"SAME"`.
+    ValueError: If input/output depth does not match `filter` shape, if padding
+      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.
 
   """
   # pylint: enable=line-too-long
@@ -444,12 +606,19 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
   except ValueError:
     ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
 
-  if not input.get_shape()[num_spatial_dims + 1].is_compatible_with(
-      filter.get_shape()[num_spatial_dims]):
+  if data_format is None or not data_format.startswith("NC"):
+    input_channels_dim = input.get_shape()[num_spatial_dims + 1]
+    spatial_dims = range(1, num_spatial_dims+1)
+  else:
+    input_channels_dim = input.get_shape()[1]
+    spatial_dims = range(2, num_spatial_dims+2)
+
+  if not input_channels_dim.is_compatible_with(filter.get_shape()[
+      num_spatial_dims]):
     raise ValueError(
         "number of input channels does not match corresponding dimension of filter, "
-        "{} != {}".format(input.get_shape()[num_spatial_dims + 1],
-                          filter.get_shape()[num_spatial_dims]))
+        "{} != {}".format(input_channels_dim, filter.get_shape()[
+            num_spatial_dims]))
 
   strides, dilation_rate = _get_strides_and_dilation_rate(
       num_spatial_dims, strides, dilation_rate)
@@ -459,12 +628,14 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
         input=input_converted,
         filter=filter,
         padding=padding,
+        data_format=data_format,
         strides=strides,
         name=name)
 
   return with_space_to_batch(
       input=input,
       filter_shape=array_ops.shape(filter),
+      spatial_dims=spatial_dims,
      dilation_rate=dilation_rate,
      padding=padding,
      op=op)
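The "NC*" behavior is fully specified by the transpose identity quoted in the docstring above; for N=2 it can be checked directly (a sketch, assuming a GPU build, since "NCHW" convolution is GPU-only at this point):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(1, 3, 5, 5), dtype=tf.float32)  # "NCHW"
f = tf.constant(np.random.rand(2, 2, 3, 4), dtype=tf.float32)  # HWIO filter

y_nc = tf.nn.convolution(input=x, filter=f, padding="VALID",
                         data_format="NCHW")
y_ref = tf.transpose(
    tf.nn.convolution(input=tf.transpose(x, [0, 2, 3, 1]),  # to "NHWC"
                      filter=f, padding="VALID"),
    [0, 3, 1, 2])                                           # back to "NCHW"

with tf.Session() as sess:
  print(np.allclose(*sess.run([y_nc, y_ref]), atol=1e-5))  # True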
@@ -476,11 +647,12 @@ def pool(input,  # pylint: disable=redefined-builtin
          padding,
          dilation_rate=None,
          strides=None,
-         name=None):
+         name=None,
+         data_format=None):
   # pylint: disable=line-too-long
   """Performs an N-D pooling operation.
 
-  Computes for
+  In the case that `data_format` does not start with "NC", computes for
       0 <= b < batch_size,
       0 <= x[i] < output_spatial_shape[i],
       0 <= c < num_channels:
@@ -498,10 +670,20 @@ def pool(input,  # pylint: disable=redefined-builtin
   [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
   The reduction never includes out-of-bounds positions.
 
+  In the case that `data_format` starts with `"NC"`, the `input` and output are
+  simply transposed as follows:
+
+    pool(input, data_format, **kwargs) =
+      tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
+                        **kwargs),
+                   [0, N+1] + range(1, N+1))
+
   Args:
     input: Tensor of rank N+2, of shape
-      [batch_size] + input_spatial_shape + [num_channels].
-      Pooling happens over the spatial dimensions only.
+      `[batch_size] + input_spatial_shape + [num_channels]` if data_format does
+      not start with "NC" (default), or
+      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
+      with "NC".  Pooling happens over the spatial dimensions only.
     window_shape: Sequence of N ints >= 1.
     pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
     padding: The padding algorithm, must be "SAME" or "VALID".
@@ -513,10 +695,22 @@ def pool(input,  # pylint: disable=redefined-builtin
       If any value of strides is > 1, then all values of dilation_rate must be
      1.
     name: Optional. Name of the op.
+    data_format: A string or None.  Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC").  For N=1, the valid values are "NWC" (default) and
+      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".  For
+      N=3, the valid value is "NDHWC".
 
   Returns:
     Tensor of rank N+2, of shape
-      [batch_size] + output_spatial_shape + [num_channels],
+      [batch_size] + output_spatial_shape + [num_channels]
+
+    if data_format is None or does not start with "NC", or
+
+      [batch_size, num_channels] + output_spatial_shape
+
+    if data_format starts with "NC",
 
     where `output_spatial_shape` depends on the value of padding:
 
     If padding = "SAME":
@@ -536,8 +730,8 @@ def pool(input,  # pylint: disable=redefined-builtin
     input = ops.convert_to_tensor(input, name="input")
 
     num_spatial_dims = len(window_shape)
-    if num_spatial_dims < 2 or num_spatial_dims > 3:
-      raise ValueError("It is required that 2 <= num_spatial_dims <= 3.")
+    if num_spatial_dims < 1 or num_spatial_dims > 3:
+      raise ValueError("It is required that 1 <= num_spatial_dims <= 3.")
 
     input.get_shape().with_rank(num_spatial_dims + 2)
 
@@ -553,8 +747,10 @@ def pool(input,  # pylint: disable=redefined-builtin
           "strides > window_shape not supported due to inconsistency between "
          "CPU and GPU implementations")
 
-    pooling_ops = {("MAX", 2): max_pool,
+    pooling_ops = {("MAX", 1): max_pool,
+                   ("MAX", 2): max_pool,
                    ("MAX", 3): max_pool3d,  # pylint: disable=undefined-variable
+                   ("AVG", 1): avg_pool,
                    ("AVG", 2): avg_pool,
                    ("AVG", 3): avg_pool3d,  # pylint: disable=undefined-variable
                   }
@@ -563,17 +759,51 @@ def pool(input,  # pylint: disable=redefined-builtin
       raise ValueError("%d-D %s pooling is not supported." %
                        (op_key[1], op_key[0]))
 
-    def op(converted_input, _, converted_padding):
-      return pooling_ops[op_key](converted_input,
-                                 [1] + list(window_shape) + [1],
-                                 [1] + list(strides) + [1],
-                                 converted_padding,
-                                 name=scope)
+    if data_format is None or not data_format.startswith("NC"):
+      adjusted_window_shape = [1] + list(window_shape) + [1]
+      adjusted_strides = [1] + list(strides) + [1]
+      spatial_dims = range(1, num_spatial_dims + 1)
+    else:
+      adjusted_window_shape = [1, 1] + list(window_shape)
+      adjusted_strides = [1, 1] + list(strides)
+      spatial_dims = range(2, num_spatial_dims + 2)
 
-    return with_space_to_batch(input=input,
-                               dilation_rate=dilation_rate,
-                               padding=padding,
-                               op=op,
-                               filter_shape=window_shape)
+    if num_spatial_dims == 3:
+      if data_format is not None and data_format != "NDHWC":
+        raise ValueError("data_format must be \"NDHWC\".")
+      data_format_kwargs = dict()
+    elif num_spatial_dims == 1:
+      if data_format is None or data_format == "NWC":
+        data_format_kwargs = dict(data_format="NHWC")
+      elif data_format == "NCW":
+        data_format_kwargs = dict(data_format="NCHW")
+      else:
+        raise ValueError("data_format must be either \"NWC\" or \"NCW\".")
+      adjusted_window_shape = [1] + adjusted_window_shape
+      adjusted_strides = [1] + adjusted_strides
+    else:
+      data_format_kwargs = dict(data_format=data_format)
+
+    def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
+      if num_spatial_dims == 1:
+        converted_input = array_ops.expand_dims(converted_input,
+                                                spatial_dims[0])
+      result = pooling_ops[op_key](converted_input,
+                                   adjusted_window_shape,
+                                   adjusted_strides,
+                                   converted_padding,
+                                   name=scope,
+                                   **data_format_kwargs)
+      if num_spatial_dims == 1:
+        result = array_ops.squeeze(result, [spatial_dims[0]])
+      return result
+
+    return with_space_to_batch(
+        input=input,
+        dilation_rate=dilation_rate,
+        padding=padding,
+        op=op,
+        spatial_dims=spatial_dims,
+        filter_shape=window_shape)
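The 1-D support added here is just an expand/squeeze wrapper around the 2-D ops, as the new `op` above shows; the same trick written out by hand (illustrative, not the commit's code):

import tensorflow as tf

x = tf.random_normal([2, 9, 3])  # "NWC": [batch, width, channels]

y1 = tf.nn.pool(input=x, window_shape=[2], pooling_type="MAX",
                padding="VALID")                  # new 1-D path
y2 = tf.squeeze(
    tf.nn.max_pool(tf.expand_dims(x, 1),          # -> [2, 1, 9, 3], "NHWC"
                   ksize=[1, 1, 2, 1],
                   strides=[1, 1, 1, 1],
                   padding="VALID"),
    [1])                                          # back to [2, 8, 3]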
@@ -1794,19 +2024,27 @@ def conv1d(value, filters, stride, padding,
            name=None):
   """Computes a 1-D convolution given 3-D input and filter tensors.
 
-  Given an input tensor of shape [batch, in_width, in_channels]
+  Given an input tensor of shape
+    [batch, in_width, in_channels]
+  if data_format is "NHWC", or
+    [batch, in_channels, in_width]
+  if data_format is "NCHW",
   and a filter / kernel tensor of shape
   [filter_width, in_channels, out_channels], this op reshapes
   the arguments to pass them to conv2d to perform the equivalent
   convolution operation.
 
-  Internally, this op reshapes the input tensors and invokes
-  `tf.nn.conv2d`.  A tensor of shape [batch, in_width, in_channels]
-  is reshaped to [batch, 1, in_width, in_channels], and the filter
-  is reshaped to [1, filter_width, in_channels, out_channels].
-  The result is then reshaped back to [batch, out_width, out_channels]
-  (where out_width is a function of the stride and padding as in
-  conv2d) and returned to the caller.
+  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
+  For example, if `data_format` does not start with "NC", a tensor of shape
+    [batch, in_width, in_channels]
+  is reshaped to
+    [batch, 1, in_width, in_channels],
+  and the filter is reshaped to
+    [1, filter_width, in_channels, out_channels].
+  The result is then reshaped back to
+    [batch, out_width, out_channels]
+  (where out_width is a function of the stride and padding as in conv2d) and
+  returned to the caller.
 
   Args:
     value: A 3D `Tensor`.  Must be of type `float32` or `float64`.
@@ -1823,16 +2061,27 @@ def conv1d(value, filters, stride, padding,
 
   Returns:
     A `Tensor`.  Has the same type as input.
+
+  Raises:
+    ValueError: if `data_format` is invalid.
   """
   with ops.name_scope(name, "conv1d", [value, filters]) as name:
-    # Reshape the input tensor to [batch, 1, in_width, in_channels]
-    value = array_ops.expand_dims(value, 1)
-    # And reshape the filter to [1, filter_width, in_channels, out_channels]
+    if data_format is None or data_format == "NHWC":
+      data_format = "NHWC"
+      spatial_start_dim = 1
+      strides = [1, 1, stride, 1]
+    elif data_format == "NCHW":
+      spatial_start_dim = 2
+      strides = [1, 1, 1, stride]
+    else:
+      raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
+    value = array_ops.expand_dims(value, spatial_start_dim)
     filters = array_ops.expand_dims(filters, 0)
-    result = gen_nn_ops.conv2d(value, filters, [1, 1, stride, 1], padding,
+    result = gen_nn_ops.conv2d(value, filters, strides, padding,
                                use_cudnn_on_gpu=use_cudnn_on_gpu,
                                data_format=data_format)
-    return array_ops.squeeze(result, [1])
+    return array_ops.squeeze(result, [spatial_start_dim])
 
 
 ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
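With the fix, conv1d honors the (2-D style) format names end to end; a minimal sketch of the "NCHW" path (GPU assumed, as with the other "NC*" layouts):

import tensorflow as tf

value = tf.random_normal([4, 8, 32])    # [batch, in_channels, in_width]
filters = tf.random_normal([3, 8, 16])  # [filter_width, in_channels, out_channels]

# The stride is now placed on the width dimension ([1, 1, 1, stride]),
# and the result is squeezed at spatial_start_dim=2 rather than 1.
out = tf.nn.conv1d(value, filters, stride=2, padding="SAME",
                   data_format="NCHW")  # shape [4, 16, 16]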