Add data_format support for N-D convolution and pooling.

This adds support for "NC*" data layouts for N-D convolution and
pooling (including atrous convolution and pooling); previously, only
"N*C" data layouts were supported.

This also adds support for 1-D pooling (by forwarding to the 2-D
implementation), and fixes the broken data_format support in
conv1d.
Change: 136556507
A. Unique TensorFlower 2016-10-18 20:28:41 -08:00 committed by TensorFlower Gardener
parent 45010d6a49
commit 8fa9b949dc
3 changed files with 453 additions and 107 deletions

tensorflow/python/kernel_tests/atrous_convolution_test.py

@@ -48,15 +48,15 @@ def upsample_filters(filters, rate):

 class AtrousConvolutionTest(tf.test.TestCase):

-  def _test_atrous_convolution(self, input_shape, filter_shape, padding,
-                               dilation_rate):
+  def _test_atrous_convolution(self, input_shape, filter_shape, dilation_rate,
+                               **kwargs):
     filters = np.arange(
         np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
     filters_upsampled = upsample_filters(filters, dilation_rate)
     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
     y1 = tf.nn.convolution(
-        input=x, filter=filters, padding=padding, dilation_rate=dilation_rate)
-    y2 = tf.nn.convolution(input=x, filter=filters_upsampled, padding=padding)
+        input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
+    y2 = tf.nn.convolution(input=x, filter=filters_upsampled, **kwargs)
     self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)

   def testAtrousConvolution2D(self):
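Note: the equivalence this test checks relies on `upsample_filters`, which is defined just above this excerpt (around line 48 of the test file) and not shown here. A sketch of an equivalent implementation, assumed rather than copied from the file: dilated convolution at rate r matches ordinary convolution with a filter that has r - 1 zeros inserted between taps.

    import numpy as np

    def upsample_filters(filters, rate):
      # Insert rate[i] - 1 zeros between filter taps along spatial axis i.
      num_spatial_dims = len(rate)
      spatial_shape = np.array(filters.shape[:num_spatial_dims])
      output_spatial_shape = (spatial_shape - 1) * rate + 1
      output = np.zeros(
          tuple(output_spatial_shape) + filters.shape[num_spatial_dims:],
          filters.dtype)
      output[tuple(slice(None, None, rate[i])
                   for i in range(num_spatial_dims))] = filters
      return output

    # A length-3 filter at rate 2 becomes [f0, 0, f1, 0, f2]:
    print(upsample_filters(np.arange(1., 4.).reshape(3, 1, 1), [2])[:, 0, 0])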
@@ -99,6 +99,24 @@ class AtrousConvolutionTest(tf.test.TestCase):
             padding=padding,
             dilation_rate=[rate])

+  def testAtrousConvolutionNC(self):
+    if tf.test.is_gpu_available():
+      # "NCW" and "NCHW" formats are not currently supported on CPU.
+      with self.test_session(use_gpu=True):
+        for padding in ["SAME", "VALID"]:
+          self._test_atrous_convolution(
+              input_shape=[2, 2, 9],
+              padding=padding,
+              filter_shape=[3, 2, 2],
+              dilation_rate=[2],
+              data_format="NCW")
+          self._test_atrous_convolution(
+              input_shape=[2, 2, 9, 5],
+              padding=padding,
+              filter_shape=[3, 3, 2, 2],
+              dilation_rate=[2, 1],
+              data_format="NCHW")
+
   def testAtrousSequence(self):
     """Tests optimization of sequence of atrous convolutions.

tensorflow/python/kernel_tests/pool_test.py

@@ -89,7 +89,7 @@ def pool_direct_single_axis(input,  # pylint: disable=redefined-builtin

 def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
-                dilation_rate, strides):
+                dilation_rate, strides, data_format=None):
   """Numpy implementation of pooling.

   This is intended for testing only, and therefore isn't particularly efficient.
@@ -103,6 +103,8 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
     padding: either "SAME" or "VALID".
     dilation_rate: Sequence of N ints >= 1.
     strides: Sequence of N ints >= 1.
+    data_format: If specified and starts with "NC", indicates that the second
+      dimension, rather than the last dimension, specifies the channel.

   Returns:
     pooling output array of rank N+2.
@@ -110,11 +112,15 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin
   Raises:
     ValueError: if arguments are invalid.
   """
+  if data_format is None or not data_format.startswith("NC"):
+    spatial_start_dim = 1
+  else:
+    spatial_start_dim = 2
   output = input
   for i in range(len(window_shape)):
     output = pool_direct_single_axis(
         input=output,
-        axis=i + 1,
+        axis=i + spatial_start_dim,
         window_size=window_shape[i],
         pooling_type=pooling_type,
         padding=padding,
@@ -125,26 +131,13 @@ def pool_direct(input, window_shape, pooling_type, padding,  # pylint: disable=redefined-builtin

 class PoolingTest(tf.test.TestCase):

-  def _test(self, input_shape, window_shape, pooling_type, padding,
-            dilation_rate, strides):
+  def _test(self, input_shape, **kwargs):
     # Use negative numbers to make sure there isn't any zero padding getting
     # used.
     x = -np.arange(
         np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
-    y1 = pool_direct(
-        input=x,
-        window_shape=window_shape,
-        pooling_type=pooling_type,
-        padding=padding,
-        dilation_rate=dilation_rate,
-        strides=strides)
-    y2 = tf.nn.pool(
-        input=x,
-        window_shape=window_shape,
-        pooling_type=pooling_type,
-        padding=padding,
-        dilation_rate=dilation_rate,
-        strides=strides)
+    y1 = pool_direct(input=x, **kwargs)
+    y2 = tf.nn.pool(input=x, **kwargs)
     self.assertAllClose(y1, y2.eval(), rtol=1e-2, atol=1e-2)

   def testPoolSimple(self):
@@ -159,6 +152,32 @@ class PoolingTest(tf.test.TestCase):
             dilation_rate=[1, 1],
             strides=[1, 2])

+  def testPool1D(self):
+    with self.test_session():
+      for padding in ["SAME", "VALID"]:
+        for pooling_type in ["MAX", "AVG"]:
+          for input_shape in [[2, 9, 2], [2, 10, 2]]:
+            for window_shape in [[1], [2], [3]]:
+              if padding != "SAME":
+                for dilation_rate in [[1], [2], [3]]:
+                  self._test(
+                      input_shape=input_shape,
+                      window_shape=window_shape,
+                      padding=padding,
+                      pooling_type=pooling_type,
+                      dilation_rate=dilation_rate,
+                      strides=[1])
+              for strides in [[1], [2], [3]]:
+                if np.any(np.array(strides) > window_shape):
+                  continue
+                self._test(
+                    input_shape=input_shape,
+                    window_shape=window_shape,
+                    padding=padding,
+                    pooling_type=pooling_type,
+                    dilation_rate=[1],
+                    strides=strides)
+
   def testPool2D(self):
     with self.test_session():
       for padding in ["SAME", "VALID"]:
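With the 1-D path in place, `tf.nn.pool` accepts rank-3 inputs directly. A minimal usage sketch (values are illustrative, not taken from the test above):

    import numpy as np
    import tensorflow as tf

    x = np.arange(18, dtype=np.float32).reshape(2, 9, 1)  # [batch, width, channels]
    with tf.Session() as sess:
      y = tf.nn.pool(input=x, window_shape=[2], pooling_type="MAX",
                     padding="VALID", strides=[2])
      print(sess.run(y).shape)  # (2, 4, 1): floor((9 - 2) / 2) + 1 = 4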
@@ -212,6 +231,40 @@ class PoolingTest(tf.test.TestCase):
             dilation_rate=[1, 1, 1],
             strides=strides)

+  def testPoolNC(self):
+    if tf.test.is_gpu_available():
+      # "NC*" format is not currently supported on CPU.
+      with self.test_session(use_gpu=True):
+        for padding in ["SAME", "VALID"]:
+          self._test(input_shape=[2, 2, 9],
+                     window_shape=[2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[1],
+                     dilation_rate=[1],
+                     data_format="NCW")
+          self._test(input_shape=[2, 2, 9],
+                     window_shape=[2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[2],
+                     dilation_rate=[1],
+                     data_format="NCW")
+          self._test(input_shape=[2, 2, 7, 9],
+                     window_shape=[2, 2],
+                     padding=padding,
+                     pooling_type="MAX",
+                     strides=[1, 2],
+                     dilation_rate=[1, 1],
+                     data_format="NCHW")
+          self._test(input_shape=[2, 2, 7, 9],
+                     window_shape=[2, 2],
+                     padding="VALID",
+                     pooling_type="MAX",
+                     strides=[1, 1],
+                     dilation_rate=[2, 2],
+                     data_format="NCHW")
+
   def _test_gradient(self, input_shape, **kwargs):
     x_val = -np.arange(
         np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
@@ -224,6 +277,32 @@ class PoolingTest(tf.test.TestCase):
     err_tolerance = 1e-2
     self.assertLess(err, err_tolerance)

+  def testGradient1D(self):
+    with self.test_session():
+      for padding in ["SAME", "VALID"]:
+        for pooling_type in ["AVG", "MAX"]:
+          for input_shape in [[2, 5, 2], [1, 4, 1]]:
+            for window_shape in [[1], [2]]:
+              if padding != "SAME":
+                for dilation_rate in [[1], [2]]:
+                  self._test_gradient(
+                      input_shape=input_shape,
+                      window_shape=window_shape,
+                      padding=padding,
+                      pooling_type=pooling_type,
+                      dilation_rate=dilation_rate,
+                      strides=[1])
+              for strides in [[1], [2]]:
+                if np.any(np.array(strides) > window_shape):
+                  continue
+                self._test(
+                    input_shape=input_shape,
+                    window_shape=window_shape,
+                    padding=padding,
+                    pooling_type=pooling_type,
+                    dilation_rate=[1],
+                    strides=strides)
+
   def testGradient2D(self):
     with self.test_session():
       for padding in ["SAME", "VALID"]:
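The body of `_test_gradient` is only partly visible above; it presumably drives a numeric-vs-symbolic gradient comparison. A sketch of the same style of check, under the assumption that it uses `tf.test.compute_gradient_error` (which existed in this era):

    import numpy as np
    import tensorflow as tf

    input_shape = [1, 5, 2]
    x_val = -np.arange(np.prod(input_shape),
                       dtype=np.float32).reshape(input_shape) - 1
    with tf.Session():
      x = tf.constant(x_val)
      y = tf.nn.pool(input=x, window_shape=[2], pooling_type="AVG",
                     padding="SAME", strides=[1])
      err = tf.test.compute_gradient_error(x, input_shape, y,
                                           y.get_shape().as_list(),
                                           x_init_value=x_val)
      assert err < 1e-2  # matches err_tolerance in the hunk above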

tensorflow/python/ops/nn_ops.py

@@ -42,7 +42,8 @@ from tensorflow.python.ops.gen_nn_ops import *
 local_response_normalization = gen_nn_ops.lrn

-def _non_atrous_convolution(input, filter, padding, strides=None, name=None):  # pylint: disable=redefined-builtin
+def _non_atrous_convolution(input, filter, padding, data_format=None,  # pylint: disable=redefined-builtin
+                            strides=None, name=None):
   """Computes sums of N-D convolutions (actually cross correlation).

   It is required that 1 <= N <= 3.
@@ -51,12 +52,22 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None):  # pylint: disable=redefined-builtin
   extends the interface of this function with a `dilation_rate` parameter.

   Args:
     input: Rank N+2 tensor of type T of shape
-      `[batch_size] + input_spatial_shape + [in_channels]`.
+      `[batch_size] + input_spatial_shape + [in_channels]` if `data_format`
+      does not start with `"NC"`, or
+      `[batch_size, in_channels] + input_spatial_shape` if `data_format` starts
+      with `"NC"`.
     filter: Rank N+2 tensor of type T of shape
       `filter_spatial_shape + [in_channels, out_channels]`. Rank of either
       `input` or `filter` must be known.
     padding: Padding method to use, must be either "VALID" or "SAME".
+    data_format: A string or None. Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC"). For N=1, the valid values are "NWC" (default) and
+      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
+      N=3, the valid value is "NDHWC".
     strides: Sequence of N positive integers, defaults to `[1] * N`.
     name: Name prefix to use.
@@ -89,26 +100,50 @@ def _non_atrous_convolution(input, filter, padding, strides=None, name=None):  # pylint: disable=redefined-builtin
       raise ValueError("len(strides)=%d, but should be %d" %
                        (len(strides), conv_dims))
     if conv_dims == 1:
-      return conv1d(value=input,
-                    filters=filter,
-                    stride=strides[0],
-                    padding=padding,
-                    name=scope)
+      # conv1d uses the 2-d data format names
+      if data_format is None or data_format == "NWC":
+        data_format_2d = "NHWC"
+      elif data_format == "NCW":
+        data_format_2d = "NCHW"
+      else:
+        raise ValueError("data_format must be \"NWC\" or \"NCW\".")
+      return conv1d(
+          value=input,
+          filters=filter,
+          stride=strides[0],
+          padding=padding,
+          data_format=data_format_2d,
+          name=scope)
     elif conv_dims == 2:
-      return gen_nn_ops.conv2d(input=input,
-                               filter=filter,
-                               strides=[1] + list(strides) + [1],
-                               padding=padding,
-                               name=name)
+      if data_format is None or data_format == "NHWC":
+        data_format = "NHWC"
+        strides = [1] + list(strides) + [1]
+      elif data_format == "NCHW":
+        strides = [1, 1] + list(strides)
+      else:
+        raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
+      return gen_nn_ops.conv2d(
+          input=input,
+          filter=filter,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          name=name)
     elif conv_dims == 3:
-      return gen_nn_ops.conv3d(input=input,
-                               filter=filter,
-                               strides=[1] + list(strides) + [1],
-                               padding=padding,
-                               name=name)
+      if data_format is None or data_format == "NDHWC":
+        strides = [1] + list(strides) + [1]
+      else:
+        raise ValueError("data_format must be \"NDHWC\".")
+      return gen_nn_ops.conv3d(
+          input=input,
+          filter=filter,
+          strides=strides,
+          padding=padding,
+          name=name)
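The stride placement above is the whole trick: the rank-(N+2) strides vector carries its batch and channel entries in layout order. A worked example for spatial strides [2, 3]:

    # Channels-last keeps batch first and channels last; channels-first
    # keeps both leading positions fixed at 1.
    spatial_strides = [2, 3]
    print([1] + spatial_strides + [1])  # NHWC -> [1, 2, 3, 1]
    print([1, 1] + spatial_strides)     # NCHW -> [1, 1, 2, 3]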

-def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):  # pylint: disable=redefined-builtin
+def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None,  # pylint: disable=redefined-builtin
+                        spatial_dims=None):
   """Performs `op` on the space-to-batch representation of `input`.

   This has the effect of transforming sliding window operations into the
@@ -122,19 +157,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
   Otherwise, it returns:

     batch_to_space_nd(
-      op(space_to_batch_nd(input, dilation_rate, paddings),
+      op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
          num_spatial_dims,
          "VALID")
-      dilation_rate,
-      crops),
+      adjusted_dilation_rate,
+      adjusted_crops),

-  where `paddings` and `crops` are int32 [num_spatial_dims, 2] tensors that
-  depend on the value of `padding`:
+  where:
+
+    adjusted_dilation_rate is an int64 tensor of shape [max(spatial_dims)],
+    adjusted_{paddings,crops} are int64 tensors of shape [max(spatial_dims), 2]
+
+  defined as follows:
+
+  We first define two int64 tensors `paddings` and `crops` of shape
+  `[num_spatial_dims, 2]` based on the value of `padding` and the spatial
+  dimensions of the `input`:

   If `padding = "VALID"`, then:

     paddings, crops = required_space_to_batch_paddings(
-      input_shape[1:num_spatial_dims+1],
+      input_shape[spatial_dims],
       dilation_rate)

   If `padding = "SAME"`, then:
@@ -143,10 +186,30 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
       filter_shape + (filter_shape - 1) * (dilation_rate - 1)

     paddings, crops = required_space_to_batch_paddings(
-      input_shape[1:num_spatial_dims+1],
+      input_shape[spatial_dims],
       dilation_rate,
       [(dilated_filter_shape - 1) // 2,
        dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])

+  Because `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
+  dimensions are contiguous starting at the second dimension, but the specified
+  `spatial_dims` may not be, we must adjust `dilation_rate`, `paddings` and
+  `crops` in order to be usable with these operations. For a given dimension,
+  if the block size is 1, and both the starting and ending padding and crop
+  amounts are 0, then space_to_batch_nd effectively leaves that dimension
+  alone, which is what is needed for dimensions not part of `spatial_dims`.
+  Furthermore, `space_to_batch_nd` and `batch_to_space_nd` handle this case
+  efficiently for any number of leading and trailing dimensions.
+
+  For 0 <= i < len(spatial_dims), we assign:
+
+    adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
+    adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
+    adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]
+
+  All unassigned values of `adjusted_dilation_rate` default to 1, while all
+  unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.
+
   Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
   padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
   `[1]*N`.
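A numpy sketch of this adjustment (a hypothetical helper mirroring the docstring, not the library code): with NCHW input the spatial dimensions are 2 and 3, so entries land at indices 1 and 2 and index 0 is filled with the identity value.

    import numpy as np

    def adjust(orig, fill_value, spatial_dims):
      # orig[i] lands at index spatial_dims[i] - 1; gaps get fill_value.
      orig = np.asarray(orig)
      out = np.full((max(spatial_dims),) + orig.shape[1:], fill_value,
                    orig.dtype)
      for i, d in enumerate(spatial_dims):
        out[d - 1] = orig[i]
      return out

    print(adjust([2, 1], 1, spatial_dims=[2, 3]))            # [1 2 1]
    print(adjust([[1, 2], [0, 1]], 0, spatial_dims=[2, 3]))  # [[0 0] [1 2] [0 1]]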
@@ -189,19 +252,23 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
     net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)

   Args:
-    input: Tensor of rank >= 1 + num_spatial_dims.
+    input: Tensor of rank > max(spatial_dims).
     dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
     padding: str constant equal to "VALID" or "SAME"
     op: Function that maps (input, num_spatial_dims, padding) -> output
     filter_shape: If padding = "SAME", specifies the shape of the convolution
       kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
       If padding = "VALID", filter_shape is ignored and need not be specified.
+    spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
+      integers (which are >= 1) specifying the spatial dimensions of `input`
+      and output. Defaults to: `range(1, num_spatial_dims+1)`.

   Returns:
     The output Tensor as described above.

   Raises:
-    ValueError: if padding is invalid or the arguments are incompatible.
+    ValueError: if `padding` is invalid or the arguments are incompatible.
+    ValueError: if `spatial_dims` are invalid.
   """
   input = ops.convert_to_tensor(input, name="input")

@@ -218,18 +285,27 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
   num_spatial_dims = rate_shape[0].value

+  if spatial_dims is None:
+    spatial_dims = range(1, num_spatial_dims + 1)
+  orig_spatial_dims = list(spatial_dims)
+  spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
+  if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
+    raise ValueError(
+        "spatial_dims must be a monotonically increasing sequence of positive integers")  # pylint: disable=line-too-long
+  last_spatial_dim = spatial_dims[-1]
+
   try:
-    input.get_shape().with_rank_at_least(num_spatial_dims + 1)
+    input.get_shape().with_rank_at_least(last_spatial_dim + 1)
   except ValueError:
     ValueError("input tensor must have rank %d at least" %
-               (num_spatial_dims + 1))
+               (last_spatial_dim + 1))

   const_rate = tensor_util.constant_value(dilation_rate)
   rate_or_const_rate = dilation_rate
   if const_rate is not None:
     rate_or_const_rate = const_rate
     if np.any(const_rate < 1):
-      raise ValueError("rate must be positive")
+      raise ValueError("dilation_rate must be positive")
     if np.all(const_rate == 1):
       return op(input, num_spatial_dims, padding)
@@ -266,26 +342,88 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None):
     raise ValueError("Invalid padding method %r" % padding)

   # Handle input whose shape is unknown during graph creation.
-  if input.get_shape().is_fully_defined():
-    input_shape = np.array(input.get_shape().as_list())
-  else:
-    input_shape = array_ops.shape(input)
-
-  input_spatial_shape = input_shape[1:num_spatial_dims+1]
+  input_spatial_shape = None
+  if input.get_shape().ndims is not None:
+    input_shape_list = input.get_shape().as_list()
+    input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
+  if input_spatial_shape is None or None in input_spatial_shape:
+    input_spatial_shape = array_ops.gather(array_ops.shape(input), spatial_dims)

   paddings, crops = array_ops.required_space_to_batch_paddings(
       input_shape=input_spatial_shape,
       base_paddings=base_paddings,
       block_shape=dilation_rate)

-  input_converted = array_ops.space_to_batch_nd(input=input,
-                                                block_shape=dilation_rate,
-                                                paddings=paddings)
+  def adjust(orig, fill_value):
+    """Returns an `adjusted` version of `orig` based on `spatial_dims`.
+
+    Tensor of the same type as `orig` and with shape
+    `[max(spatial_dims), ...]` where:
+
+      adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]
+
+    for 0 <= i < len(spatial_dims), and
+
+      adjusted[j, ...] = fill_value
+
+    for j != spatial_dims[i] - 1 for some i.
+
+    If `orig` is a constant value, then the result will be a constant value.
+
+    Args:
+      orig: Tensor of rank >= 1.
+      fill_value: Numpy scalar (of same data type as `orig`) specifying the
+        fill value for non-spatial dimensions.
+
+    Returns:
+      `adjusted` tensor.
+    """
+    fill_dims = orig.get_shape().as_list()[1:]
+    dtype = orig.dtype.as_numpy_dtype
+    parts = []
+    const_orig = tensor_util.constant_value(orig)
+    const_or_orig = const_orig if const_orig is not None else orig
+    prev_spatial_dim = 0
+    i = 0
+    while i < len(spatial_dims):
+      start_i = i
+      start_spatial_dim = spatial_dims[i]
+      if start_spatial_dim > 1:
+        # Fill in any gap from the previous spatial dimension (or dimension 1
+        # if this is the first spatial dimension) with `fill_value`.
+        parts.append(
+            np.full(
+                [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
+                fill_value,
+                dtype=dtype))
+      # Find the largest value of i such that:
+      #   [spatial_dims[start_i], ..., spatial_dims[i]]
+      #     == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
+      # i.e. the end of a contiguous group of spatial dimensions.
+      while (i + 1 < len(spatial_dims) and
+             spatial_dims[i + 1] == spatial_dims[i] + 1):
+        i += 1
+      parts.append(const_or_orig[start_i:i + 1])
+      prev_spatial_dim = spatial_dims[i]
+      i += 1
+    if const_orig is not None:
+      return np.concatenate(parts)
+    else:
+      return array_ops.concat(0, parts)
+
+  dilation_rate = adjust(dilation_rate, 1)
+  paddings = adjust(paddings, 0)
+  crops = adjust(crops, 0)
+
+  input_converted = array_ops.space_to_batch_nd(
+      input=input,
+      block_shape=dilation_rate,
+      paddings=paddings)

   result = op(input_converted, num_spatial_dims, "VALID")

-  result_converted = array_ops.batch_to_space_nd(input=result,
-                                                 block_shape=dilation_rate,
-                                                 crops=crops)
+  result_converted = array_ops.batch_to_space_nd(
+      input=result, block_shape=dilation_rate, crops=crops)

   return result_converted
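The correctness of this wrapper rests on `space_to_batch_nd` and `batch_to_space_nd` being exact inverses when the crops mirror the paddings. A quick sanity check of that identity, assuming the public wrappers of this era (illustrative values):

    import numpy as np
    import tensorflow as tf

    x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
    with tf.Session() as sess:
      stb = tf.space_to_batch_nd(x, block_shape=[2, 2],
                                 paddings=[[0, 0], [0, 0]])
      bts = tf.batch_to_space_nd(stb, block_shape=[2, 2],
                                 crops=[[0, 0], [0, 0]])
      np.testing.assert_array_equal(sess.run(bts), x)  # round trip is exact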
@@ -333,7 +471,8 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):

 def convolution(input, filter,  # pylint: disable=redefined-builtin
-                padding, strides=None, dilation_rate=None, name=None):
+                padding, strides=None, dilation_rate=None,
+                name=None, data_format=None):
   # pylint: disable=line-too-long
   """Computes sums of N-D convolutions (actually cross-correlation).
@@ -343,7 +482,8 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
   the optional `dilation_rate` parameter. Currently, however, output striding
   is not supported for atrous convolutions.

-  Specifically, given rank (N+2) `input` Tensor of shape
+  Specifically, in the case that `data_format` does not start with "NC", given
+  a rank (N+2) `input` Tensor of shape

     [num_batches,
      input_spatial_shape[0],
@@ -368,23 +508,34 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
       sum_{z[0], ..., z[N-1], q}
-          filters[z[0], ..., z[N-1], q, k] *
+          filter[z[0], ..., z[N-1], q, k] *
           padded_input[b,
                        x[0]*strides[0] + dilation_rate[0]*z[0],
                        ...,
                        x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
-                       q],
+                       q]

   where `padded_input` is obtained by zero padding the input using an effective
   spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
   output striding `strides` as described in the
   [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).

+  In the case that `data_format` does start with `"NC"`, the `input` and output
+  (but not the `filter`) are simply transposed as follows:
+
+    convolution(input, data_format, **kwargs) =
+      tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
+                               **kwargs),
+                   [0, N+1] + range(1, N+1))
+
   It is required that 1 <= N <= 3.

   Args:
     input: An N-D `Tensor` of type `T`, of shape
-      `[batch_size] + input_spatial_shape + [in_channels]`.
+      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
+      not start with "NC" (default), or
+      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
+      with "NC".
     filter: An N-D `Tensor` with the same type as `input` and shape
       `spatial_filter_shape + [in_channels, out_channels]`.
     padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
@@ -400,12 +551,23 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
       filter in each spatial dimension i. If any value of dilation_rate is > 1,
       then all values of strides must be 1.
     name: Optional name for the returned tensor.
+    data_format: A string or None. Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC"). For N=1, the valid values are "NWC" (default) and
+      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
+      N=3, the valid value is "NDHWC".

   Returns:
-    A `Tensor` with the same type as `value` of shape
+    A `Tensor` with the same type as `input` of shape

-      `[batch_size] + output_spatial_shape + [out_channels]`,
+      `[batch_size] + output_spatial_shape + [out_channels]`
+
+    if data_format is None or does not start with "NC", or
+
+      `[batch_size, out_channels] + output_spatial_shape`
+
+    if data_format starts with "NC",

     where `output_spatial_shape` depends on the value of `padding`.

     If padding == "SAME":
@@ -418,8 +580,8 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
           / strides[i]).

   Raises:
-    ValueError: If input/output depth does not match `filter` shape, or if
-      padding is other than `"VALID"` or `"SAME"`.
+    ValueError: If input/output depth does not match `filter` shape, if padding
+      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.
   """
   # pylint: enable=line-too-long
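A sketch of the transpose equivalence the docstring states, for N=2. The NCHW kernel is GPU-only in this era, so this is illustrative rather than a portable test; values are assumed:

    import numpy as np
    import tensorflow as tf

    x_nchw = np.random.rand(2, 3, 9, 5).astype(np.float32)  # [batch, C, H, W]
    filt = np.random.rand(3, 3, 3, 4).astype(np.float32)    # [fH, fW, in, out]

    with tf.Session() as sess:
      # Direct channels-first path added by this commit.
      y1 = tf.nn.convolution(input=x_nchw, filter=filt, padding="SAME",
                             data_format="NCHW")
      # Equivalent channels-last path: transpose in, convolve, transpose back.
      x_nhwc = tf.transpose(x_nchw, [0, 2, 3, 1])
      y2 = tf.transpose(
          tf.nn.convolution(input=x_nhwc, filter=filt, padding="SAME"),
          [0, 3, 1, 2])
      a, b = sess.run([y1, y2])
      np.testing.assert_allclose(a, b, rtol=1e-4)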
@@ -444,12 +606,19 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
     except ValueError:
       ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))

-    if not input.get_shape()[num_spatial_dims + 1].is_compatible_with(
-        filter.get_shape()[num_spatial_dims]):
+    if data_format is None or not data_format.startswith("NC"):
+      input_channels_dim = input.get_shape()[num_spatial_dims + 1]
+      spatial_dims = range(1, num_spatial_dims + 1)
+    else:
+      input_channels_dim = input.get_shape()[1]
+      spatial_dims = range(2, num_spatial_dims + 2)
+
+    if not input_channels_dim.is_compatible_with(filter.get_shape()[
+        num_spatial_dims]):
       raise ValueError(
           "number of input channels does not match corresponding dimension of filter, "
-          "{} != {}".format(input.get_shape()[num_spatial_dims + 1],
-                            filter.get_shape()[num_spatial_dims]))
+          "{} != {}".format(input_channels_dim, filter.get_shape()[
+              num_spatial_dims]))

     strides, dilation_rate = _get_strides_and_dilation_rate(
         num_spatial_dims, strides, dilation_rate)
@@ -459,12 +628,14 @@ def convolution(input, filter,  # pylint: disable=redefined-builtin
           input=input_converted,
           filter=filter,
           padding=padding,
+          data_format=data_format,
           strides=strides,
           name=name)

     return with_space_to_batch(
         input=input,
         filter_shape=array_ops.shape(filter),
+        spatial_dims=spatial_dims,
         dilation_rate=dilation_rate,
         padding=padding,
         op=op)
@@ -476,11 +647,12 @@ def pool(input,  # pylint: disable=redefined-builtin
          padding,
          dilation_rate=None,
          strides=None,
-         name=None):
+         name=None,
+         data_format=None):
   # pylint: disable=line-too-long
   """Performs an N-D pooling operation.

-  Computes for
+  In the case that `data_format` does not start with "NC", computes for
       0 <= b < batch_size,
       0 <= x[i] < output_spatial_shape[i],
       0 <= c < num_channels:
@@ -498,10 +670,20 @@ def pool(input,  # pylint: disable=redefined-builtin
   [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution).
   The reduction never includes out-of-bounds positions.

+  In the case that `data_format` starts with `"NC"`, the `input` and output are
+  simply transposed as follows:
+
+    pool(input, data_format, **kwargs) =
+      tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
+                        **kwargs),
+                   [0, N+1] + range(1, N+1))
+
   Args:
     input: Tensor of rank N+2, of shape
-      [batch_size] + input_spatial_shape + [num_channels].
-      Pooling happens over the spatial dimensions only.
+      `[batch_size] + input_spatial_shape + [num_channels]` if data_format does
+      not start with "NC" (default), or
+      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
+      with "NC". Pooling happens over the spatial dimensions only.
     window_shape: Sequence of N ints >= 1.
     pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
     padding: The padding algorithm, must be "SAME" or "VALID".
@@ -513,10 +695,22 @@ def pool(input,  # pylint: disable=redefined-builtin
       If any value of strides is > 1, then all values of dilation_rate must be
       1.
     name: Optional. Name of the op.
+    data_format: A string or None. Specifies whether the channel dimension of
+      the `input` and output is the last dimension (default, or if `data_format`
+      does not start with "NC"), or the second dimension (if `data_format`
+      starts with "NC"). For N=1, the valid values are "NWC" (default) and
+      "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
+      N=3, the valid value is "NDHWC".

   Returns:
     Tensor of rank N+2, of shape
-      [batch_size] + output_spatial_shape + [num_channels],
+      [batch_size] + output_spatial_shape + [num_channels]
+
+    if data_format is None or does not start with "NC", or
+
+      [batch_size, num_channels] + output_spatial_shape
+
+    if data_format starts with "NC",

     where `output_spatial_shape` depends on the value of padding:

     If padding = "SAME":
@@ -536,8 +730,8 @@ def pool(input,  # pylint: disable=redefined-builtin
     input = ops.convert_to_tensor(input, name="input")
     num_spatial_dims = len(window_shape)
-    if num_spatial_dims < 2 or num_spatial_dims > 3:
-      raise ValueError("It is required that 2 <= num_spatial_dims <= 3.")
+    if num_spatial_dims < 1 or num_spatial_dims > 3:
+      raise ValueError("It is required that 1 <= num_spatial_dims <= 3.")

     input.get_shape().with_rank(num_spatial_dims + 2)
@@ -553,8 +747,10 @@ def pool(input,  # pylint: disable=redefined-builtin
           "strides > window_shape not supported due to inconsistency between "
           "CPU and GPU implementations")

-    pooling_ops = {("MAX", 2): max_pool,
+    pooling_ops = {("MAX", 1): max_pool,
+                   ("MAX", 2): max_pool,
                    ("MAX", 3): max_pool3d,  # pylint: disable=undefined-variable
+                   ("AVG", 1): avg_pool,
                    ("AVG", 2): avg_pool,
                    ("AVG", 3): avg_pool3d,  # pylint: disable=undefined-variable
                   }
@@ -563,18 +759,52 @@ def pool(input,  # pylint: disable=redefined-builtin
       raise ValueError("%d-D %s pooling is not supported." %
                        (op_key[1], op_key[0]))

-    def op(converted_input, _, converted_padding):
-      return pooling_ops[op_key](converted_input,
-                                 [1] + list(window_shape) + [1],
-                                 [1] + list(strides) + [1],
-                                 converted_padding,
-                                 name=scope)
+    if data_format is None or not data_format.startswith("NC"):
+      adjusted_window_shape = [1] + list(window_shape) + [1]
+      adjusted_strides = [1] + list(strides) + [1]
+      spatial_dims = range(1, num_spatial_dims + 1)
+    else:
+      adjusted_window_shape = [1, 1] + list(window_shape)
+      adjusted_strides = [1, 1] + list(strides)
+      spatial_dims = range(2, num_spatial_dims + 2)
+
+    if num_spatial_dims == 3:
+      if data_format is not None and data_format != "NDHWC":
+        raise ValueError("data_format must be \"NDHWC\".")
+      data_format_kwargs = dict()
+    elif num_spatial_dims == 1:
+      if data_format is None or data_format == "NWC":
+        data_format_kwargs = dict(data_format="NHWC")
+      elif data_format == "NCW":
+        data_format_kwargs = dict(data_format="NCHW")
+      else:
+        raise ValueError("data_format must be either \"NWC\" or \"NCW\".")
+      adjusted_window_shape = [1] + adjusted_window_shape
+      adjusted_strides = [1] + adjusted_strides
+    else:
+      data_format_kwargs = dict(data_format=data_format)
+
+    def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
+      if num_spatial_dims == 1:
+        converted_input = array_ops.expand_dims(converted_input,
+                                                spatial_dims[0])
+      result = pooling_ops[op_key](converted_input,
+                                   adjusted_window_shape,
+                                   adjusted_strides,
+                                   converted_padding,
+                                   name=scope,
+                                   **data_format_kwargs)
+      if num_spatial_dims == 1:
+        result = array_ops.squeeze(result, [spatial_dims[0]])
+      return result

-    return with_space_to_batch(input=input,
-                               dilation_rate=dilation_rate,
-                               padding=padding,
-                               op=op,
-                               filter_shape=window_shape)
+    return with_space_to_batch(
+        input=input,
+        dilation_rate=dilation_rate,
+        padding=padding,
+        op=op,
+        spatial_dims=spatial_dims,
+        filter_shape=window_shape)


 def atrous_conv2d(value, filters, rate, padding, name=None):
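The 1-D branch above is the same expand/squeeze trick `conv1d` uses: insert a dummy height dimension, run the 2-D kernel, squeeze it back out. A standalone illustration with the public ops (values assumed):

    import numpy as np
    import tensorflow as tf

    x = np.arange(10, dtype=np.float32).reshape(1, 10, 1)  # NWC
    with tf.Session() as sess:
      x2d = tf.expand_dims(x, 1)  # NWC -> "NHWC" with H = 1
      y2d = tf.nn.max_pool(x2d, ksize=[1, 1, 2, 1], strides=[1, 1, 2, 1],
                           padding="VALID")
      y = tf.squeeze(y2d, [1])    # back to NWC
      print(sess.run(y)[0, :, 0])  # [1. 3. 5. 7. 9.]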
@@ -1794,19 +2024,27 @@ def conv1d(value, filters, stride, padding,
            name=None):
   """Computes a 1-D convolution given 3-D input and filter tensors.

-  Given an input tensor of shape [batch, in_width, in_channels]
+  Given an input tensor of shape
+    [batch, in_width, in_channels]
+  if data_format is "NHWC", or
+    [batch, in_channels, in_width]
+  if data_format is "NCHW",
   and a filter / kernel tensor of shape
   [filter_width, in_channels, out_channels], this op reshapes
   the arguments to pass them to conv2d to perform the equivalent
   convolution operation.

-  Internally, this op reshapes the input tensors and invokes
-  `tf.nn.conv2d`. A tensor of shape [batch, in_width, in_channels]
-  is reshaped to [batch, 1, in_width, in_channels], and the filter
-  is reshaped to [1, filter_width, in_channels, out_channels].
-  The result is then reshaped back to [batch, out_width, out_channels]
-  (where out_width is a function of the stride and padding as in
-  conv2d) and returned to the caller.
+  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
+  For example, if `data_format` does not start with "NC", a tensor of shape
+    [batch, in_width, in_channels]
+  is reshaped to
+    [batch, 1, in_width, in_channels],
+  and the filter is reshaped to
+    [1, filter_width, in_channels, out_channels].
+  The result is then reshaped back to
+    [batch, out_width, out_channels]
+  (where out_width is a function of the stride and padding as in conv2d) and
+  returned to the caller.

   Args:
     value: A 3D `Tensor`. Must be of type `float32` or `float64`.
@@ -1823,16 +2061,27 @@ def conv1d(value, filters, stride, padding,
   Returns:
     A `Tensor`. Has the same type as input.

+  Raises:
+    ValueError: if `data_format` is invalid.
   """
   with ops.name_scope(name, "conv1d", [value, filters]) as name:
     # Reshape the input tensor to [batch, 1, in_width, in_channels]
-    value = array_ops.expand_dims(value, 1)
-    # And reshape the filter to [1, filter_width, in_channels, out_channels]
+    if data_format is None or data_format == "NHWC":
+      data_format = "NHWC"
+      spatial_start_dim = 1
+      strides = [1, 1, stride, 1]
+    elif data_format == "NCHW":
+      spatial_start_dim = 2
+      strides = [1, 1, 1, stride]
+    else:
+      raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
+    value = array_ops.expand_dims(value, spatial_start_dim)
     filters = array_ops.expand_dims(filters, 0)
-    result = gen_nn_ops.conv2d(value, filters, [1, 1, stride, 1], padding,
+    result = gen_nn_ops.conv2d(value, filters, strides, padding,
                                use_cudnn_on_gpu=use_cudnn_on_gpu,
                                data_format=data_format)
-    return array_ops.squeeze(result, [1])
+    return array_ops.squeeze(result, [spatial_start_dim])


 ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
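With the fix, `conv1d` threads `data_format` through to `conv2d` instead of ignoring it. A usage sketch; note that at this point in history `conv1d` takes the 2-D format names "NHWC"/"NCHW", and the channels-first kernel is GPU-only, though shape inference works on any build:

    import numpy as np
    import tensorflow as tf

    value = np.random.rand(2, 4, 9).astype(np.float32)    # [batch, channels, width]
    filters = np.random.rand(3, 4, 8).astype(np.float32)  # [width, in, out]
    y = tf.nn.conv1d(value, filters, stride=2, padding="SAME",
                     data_format="NCHW")
    print(y.get_shape())  # (2, 8, 5): 8 output channels, ceil(9 / 2) = 5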