[TF] Add support for more than one outer batch dimension to tf.nn.convolution.

This is part 2/N of adding outer batch dimension support to tf.nn.convXd and keras.layers.ConvXd.

Also added support for batch_shape.ndims > 1 to nn_ops.Convolution and other internal
libraries, so that we can use this in keras.layers.ConvXd.

For now, using tf.nn.convolution with a rank-3 or rank-5 filter (conv1d or conv3d) still
raises an error deep in the ops, because I haven't yet added reshape wrappers for
gen_nn_ops.conv{1d,3d}; those will be easy to add once this is in. I wanted to make
sure it works for conv2d first.

No public signature changes.
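
For illustration, a minimal sketch of the newly supported call pattern, mirroring the
tests added below:

    import numpy as np
    import tensorflow as tf

    # Two outer batch dimensions ([2, 5]), 2x3 spatial dims, 3 channels.
    x = np.random.rand(2, 5, 2, 3, 3).astype(np.float32)
    # A 1x1 filter mapping 3 input channels to 3 output channels.
    filters = np.random.rand(1, 1, 3, 3).astype(np.float32)

    # The same call as before; the extra leading batch dimension is now
    # accepted and preserved in the output.
    y = tf.nn.convolution(x, filters, strides=[1, 1], padding="VALID")
    print(y.shape)  # (2, 5, 2, 3, 3)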

Rollback of rollback with fixes.

PiperOrigin-RevId: 312735044
Change-Id: I4b4497a2925a965fa45f1812d7bd25d7a2c087ac
Eugene Brevdo authored on 2020-05-21 13:53:55 -07:00; committed by TensorFlower Gardener
parent 57c5d33f89
commit 7d9d943192
2 changed files with 297 additions and 81 deletions
tensorflow/python


@@ -455,6 +455,58 @@ class Conv2DTest(test.TestCase):
conv1,
self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionClass2DExpandedBatch(self):
tensor_in_sizes_batch = [10, 2, 3, 3]
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
filter_in_sizes = [1, 1, 3, 3]
filter_in = self._CreateNumpyTensor(filter_in_sizes)
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
convolver1 = nn_ops.Convolution(
input_shape=x1.shape,
filter_shape=filter_in.shape,
strides=[1, 1],
padding="VALID")
self.assertEqual(convolver1.num_batch_dims, 1)
convolver2 = nn_ops.Convolution(
input_shape=x2.shape,
filter_shape=filter_in.shape,
strides=[1, 1],
padding="VALID")
self.assertEqual(convolver2.num_batch_dims, 2)
conv1 = convolver1(x1, filter_in)
conv2 = convolver2(x2, filter_in)
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
self.assertAllEqual(
conv1,
self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
tensor_in_sizes_batch = [10, 2, 3, 3]
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
filter_in_sizes = [1, 1, 3, 3]
filter_in = self._CreateNumpyTensor(filter_in_sizes)
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
conv1 = nn_ops.convolution(
x1,
filter_in,
strides=[1, 1],
padding="VALID")
conv2 = nn_ops.convolution(
x2,
filter_in,
strides=[1, 1],
padding="VALID")
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
self.assertAllEqual(
conv1,
self.evaluate(conv2).reshape(conv1.shape))
@test_util.run_in_graph_and_eager_modes
def testConv2D2x2Filter2x1Dilation(self):
self._VerifyDilatedConvValues(


@@ -131,9 +131,9 @@ def _non_atrous_convolution(
"""
with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
input_shape = input.get_shape()
input_shape = input.shape
filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin
filter_shape = filter.get_shape()
filter_shape = filter.shape
op = _NonAtrousConvolution(
input_shape,
filter_shape=filter_shape,
@@ -148,36 +148,51 @@ class _NonAtrousConvolution(object):
"""Helper class for _non_atrous_convolution.
Note that this class assumes that shapes of input and filter passed to
__call__ are compatible with input_shape and filter_shape passed to the
`__call__` are compatible with `input_shape` and `filter_shape` passed to the
constructor.
Arguments:
input_shape: static input shape, i.e. input.get_shape().
filter_shape: static filter shape, i.e. filter.get_shape().
input_shape: static input shape, i.e. input.shape.
filter_shape: static filter shape, i.e. filter.shape.
padding: see _non_atrous_convolution.
data_format: see _non_atrous_convolution.
strides: see _non_atrous_convolution.
name: see _non_atrous_convolution.
num_batch_dims: (Optional.) The number of batch dimensions in the input;
if not provided, the default of `1` is used.
"""
def __init__(
self,
input_shape,
filter_shape, # pylint: disable=redefined-builtin
filter_shape,
padding,
data_format=None,
strides=None,
name=None):
filter_shape = filter_shape.with_rank(input_shape.ndims)
name=None,
num_batch_dims=1):
# filter shape is always rank num_spatial_dims + 2
# and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
if input_shape.ndims is not None:
filter_shape = filter_shape.with_rank(
input_shape.ndims - num_batch_dims + 1)
self.padding = padding
self.name = name
input_shape = input_shape.with_rank(filter_shape.ndims)
# input shape rank == num_spatial_dims + num_batch_dims + 1
# and filter_shape is always rank num_spatial_dims + 2
if filter_shape.ndims is not None:
input_shape = input_shape.with_rank(
filter_shape.ndims + num_batch_dims - 1)
if input_shape.ndims is None:
raise ValueError("Rank of convolution must be known")
if input_shape.ndims < 3 or input_shape.ndims > 5:
raise ValueError(
"`input` and `filter` must have rank at least 3 and at most 5")
conv_dims = input_shape.ndims - 2
"Rank of convolution must be known, but saw input_shape.ndims == {}"
.format(input_shape.ndims))
if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
raise ValueError(
"`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
"most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
.format(input_shape.ndims, num_batch_dims))
conv_dims = input_shape.ndims - num_batch_dims - 1
if strides is None:
strides = [1] * conv_dims
elif len(strides) != conv_dims:
@@ -520,7 +535,7 @@ def with_space_to_batch(
"""
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
input_shape = input.get_shape()
input_shape = input.shape
def build_op(num_spatial_dims, padding):
return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -540,18 +555,19 @@ class _WithSpaceToBatch(object):
"""Helper class for with_space_to_batch.
Note that this class assumes that shapes of input and filter passed to
__call__ are compatible with input_shape and filter_shape passed to the
constructor.
`__call__` are compatible with `input_shape`, `filter_shape`, and
`spatial_dims` passed to the constructor.
Arguments:
input_shape: static shape of input. i.e. input.get_shape().
dilation_rate: see with_space_to_batch
padding: see with_space_to_batch
input_shape: static shape of input, i.e. input.shape.
dilation_rate: see `with_space_to_batch`.
padding: see `with_space_to_batch`.
build_op: Function that maps (num_spatial_dims, paddings) -> (function that
maps (input, filter) -> output).
filter_shape: see with_space_to_batch
spatial_dims: see with_space_to_batch
data_format: see with_space_to_batch
filter_shape: see `with_space_to_batch`.
spatial_dims: see `with_space_to_batch`.
data_format: see `with_space_to_batch`.
num_batch_dims: (Optional.) Number of batch dims in `input_shape`.
"""
def __init__(self,
@@ -561,24 +577,25 @@ class _WithSpaceToBatch(object):
build_op,
filter_shape=None,
spatial_dims=None,
data_format=None):
data_format=None,
num_batch_dims=1):
"""Helper class for _with_space_to_batch."""
dilation_rate = ops.convert_to_tensor(
dilation_rate, dtypes.int32, name="dilation_rate")
try:
rate_shape = dilation_rate.get_shape().with_rank(1)
except ValueError:
raise ValueError("rate must be rank 1")
if dilation_rate.shape.ndims not in (None, 1):
raise ValueError(
"rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))
if not dilation_rate.get_shape().is_fully_defined():
raise ValueError("rate must have known shape")
if not dilation_rate.shape.is_fully_defined():
raise ValueError("rate must have known shape, but saw {}"
.format(dilation_rate.shape))
num_spatial_dims = rate_shape.dims[0].value
num_spatial_dims = dilation_rate.shape.dims[0].value
if data_format is not None and data_format.startswith("NC"):
starting_spatial_dim = 2
starting_spatial_dim = num_batch_dims + 1
else:
starting_spatial_dim = 1
starting_spatial_dim = num_batch_dims
if spatial_dims is None:
spatial_dims = range(starting_spatial_dim,
@@ -588,7 +605,7 @@ class _WithSpaceToBatch(object):
if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
raise ValueError(
"spatial_dims must be a monotonically increasing sequence of "
"positive integers")
"positive integers, but saw: {}".format(orig_spatial_dims))
if data_format is not None and data_format.startswith("NC"):
expected_input_rank = spatial_dims[-1]
@@ -599,14 +616,16 @@ class _WithSpaceToBatch(object):
input_shape.with_rank_at_least(expected_input_rank)
except ValueError:
raise ValueError(
"input tensor must have rank %d at least" % (expected_input_rank))
"input tensor must have rank at least {}, but saw rank {}"
.format(expected_input_rank, input_shape.ndims))
const_rate = tensor_util.constant_value(dilation_rate)
rate_or_const_rate = dilation_rate
if const_rate is not None:
rate_or_const_rate = const_rate
if np.any(const_rate < 1):
raise ValueError("dilation_rate must be positive")
raise ValueError("dilation_rate must be positive, but saw: {}"
.format(const_rate))
if np.all(const_rate == 1):
self.call = build_op(num_spatial_dims, padding)
return
@@ -672,6 +691,7 @@ class _WithSpaceToBatch(object):
filter_shape = array_ops.shape(filter)
base_paddings = _with_space_to_batch_base_paddings(
filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
paddings, crops = array_ops.required_space_to_batch_paddings(
input_shape=input_spatial_shape,
base_paddings=base_paddings,
@@ -994,31 +1014,84 @@ def convolution_internal(
data_format=None,
dilations=None,
name=None,
call_from_convolution=True):
"""Internal function which performs rank agnostic convolution."""
if isinstance(input.shape, tensor_shape.TensorShape) and \
input.shape.rank is not None:
n = len(input.shape) - 2
elif not isinstance(input.shape, tensor_shape.TensorShape) and \
input.shape is not None:
n = len(input.shape) - 2
elif isinstance(filters.shape, tensor_shape.TensorShape) and \
filters.shape.rank is not None:
call_from_convolution=True,
num_spatial_dims=None):
"""Internal function which performs rank agnostic convolution.
Args:
input: See `convolution`.
filters: See `convolution`.
strides: See `convolution`.
padding: See `convolution`.
data_format: See `convolution`.
dilations: See `convolution`.
name: See `convolution`.
call_from_convolution: See `convolution`.
num_spatial_dims: (Optional.) An integer describing the rank of the spatial
dimensions. For `1-D`, `2-D` and `3-D` convolutions, the value of
`num_spatial_dims` is `1`, `2`, and `3`, respectively.
This argument is only required to disambiguate the rank of `batch_shape`
when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
backwards compatibility, if `num_spatial_dims is None` and
`filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
`1` (i.e., the input is expected to be
`[batch_size, num_channels] + input_spatial_shape`
or `[batch_size] + input_spatial_shape + [num_channels]`).
Returns:
A tensor of shape and dtype matching that of `input`.
Raises:
ValueError: If input and filter both have unknown shapes, or if
`num_spatial_dims` is provided and incompatible with the value
estimated from `filters.shape`.
"""
n = None
if getattr(filters, 'shape', None) is None:
with ops.name_scope(name, 'convolution_internal', [filters, input]):
filters = ops.convert_to_tensor(filters, name='filters')
if (isinstance(filters.shape, tensor_shape.TensorShape)
and filters.shape.rank is not None):
n = len(filters.shape) - 2
elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
filters.shape is not None:
elif (not isinstance(filters.shape, tensor_shape.TensorShape)
and filters.shape is not None):
n = len(filters.shape) - 2
if (isinstance(input.shape, tensor_shape.TensorShape)
and input.shape.rank is not None):
if n is None:
n = (num_spatial_dims if num_spatial_dims is not None
else len(input.shape) - 2)
num_batch_dims = len(input.shape) - n - 1
elif (not isinstance(input.shape, tensor_shape.TensorShape)
and input.shape is not None):
if n is None:
n = (num_spatial_dims if num_spatial_dims is not None
else len(input.shape) - 2)
num_batch_dims = len(input.shape) - n - 1
else:
num_batch_dims = 1 # Default behavior if it cannot be estimated.
if n is None:
raise ValueError("rank of input or filter must be known")
if num_spatial_dims is not None and n != num_spatial_dims:
raise ValueError(
"inconsistent estimate of spatial dims ({}) vs. actual passed "
"num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
"but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
if not 1 <= n <= 3:
raise ValueError(
"Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
"of 1, 2 or 3 but saw {}. num_batch_dims: {}."
.format(n, num_batch_dims))
if data_format is None:
channel_index = n + 1
channel_index = num_batch_dims + n
else:
channel_index = 1 if data_format.startswith("NC") else n + 1
channel_index = (
num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
strides = _get_sequence(strides, n, channel_index, "strides")
dilations = _get_sequence(dilations, n, channel_index, "dilations")
@@ -1031,7 +1104,7 @@ def convolution_internal(
scope = "convolution"
with ops.name_scope(name, scope, [input, filters]) as name:
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
if device_context.enclosing_tpu_context() is not None or all(
i == 1 for i in dilations):
@@ -1061,7 +1134,8 @@ def convolution_internal(
strides=strides,
dilation_rate=dilations,
name=name,
data_format=data_format)
data_format=data_format,
num_spatial_dims=n)
return op(input, filters)
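
The num_spatial_dims argument threaded into Convolution here is what lets callers
disambiguate how many leading dimensions are batch when the filter rank is unknown.
A minimal sketch of passing it explicitly (hypothetical shapes; redundant in this
case, since the filter shape is fully known):

    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.ops import nn_ops

    convolver = nn_ops.Convolution(
        input_shape=tensor_shape.TensorShape([2, 5, 2, 3, 3]),  # two batch dims
        filter_shape=tensor_shape.TensorShape([1, 1, 3, 3]),
        strides=[1, 1],
        padding="VALID",
        num_spatial_dims=2)
    # num_batch_dims is derived as input rank - num_spatial_dims - 1 == 2.
    assert convolver.num_batch_dims == 2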
@@ -1069,17 +1143,34 @@ class Convolution(object):
"""Helper class for convolution.
Note that this class assumes that shapes of input and filter passed to
__call__ are compatible with input_shape and filter_shape passed to the
constructor.
`__call__` are compatible with `input_shape`, `filter_shape`, and
`num_spatial_dims` passed to the constructor.
Arguments:
input_shape: static shape of input. i.e. input.get_shape().
filter_shape: static shape of the filter. i.e. filter.get_shape().
padding: see convolution.
input_shape: static shape of input, i.e. input.shape. It is
`batch_shape + input_spatial_shape + [num_channels]` if `data_format`
does not start with `NC`, or
`batch_shape + [num_channels] + input_spatial_shape` if `data_format`
starts with `NC`.
filter_shape: static shape of the filter, i.e. filter.shape.
padding: The padding algorithm, must be "SAME" or "VALID".
strides: see convolution.
dilation_rate: see convolution.
name: see convolution.
data_format: see convolution.
data_format: A string or `None`. Specifies whether the channel dimension of
the `input` and output is the last dimension (if `data_format` is `None`
or does not start with `NC`), or the first post-batch dimension (i.e. if
`data_format` starts with `NC`).
num_spatial_dims: (Usually optional.) Python integer, the number of spatial
dimensions. For `1-D`, `2-D` and `3-D` convolutions, the value of
`num_spatial_dims` is `1`, `2`, and `3`, respectively.
This argument is only required to disambiguate the rank of `batch_shape`
when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
backwards compatibility, if `num_spatial_dims is None` and
`filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
`1` (i.e., the input is expected to be
`[batch_size, num_channels] + input_spatial_shape`
or `[batch_size] + input_spatial_shape + [num_channels]`).
"""
def __init__(self,
@@ -1089,40 +1180,72 @@ class Convolution(object):
strides=None,
dilation_rate=None,
name=None,
data_format=None):
data_format=None,
num_spatial_dims=None):
"""Helper function for convolution."""
num_total_dims = filter_shape.ndims
if num_total_dims is None:
num_total_dims = input_shape.ndims
if num_total_dims is None:
raise ValueError("rank of input or filter must be known")
num_batch_dims = None
filter_shape = tensor_shape.as_shape(filter_shape)
input_shape = tensor_shape.as_shape(input_shape)
num_spatial_dims = num_total_dims - 2
if filter_shape.ndims is not None:
if (num_spatial_dims is not None and
filter_shape.ndims != num_spatial_dims + 2):
raise ValueError(
"Expected filter_shape.ndims == num_spatial_dims + 2, "
"but saw filter_shape.ndims == {} and num_spatial_dims == {}"
.format(filter_shape.ndims, num_spatial_dims))
else:
num_spatial_dims = filter_shape.ndims - 2
try:
input_shape.with_rank(num_spatial_dims + 2)
except ValueError:
if input_shape.ndims is not None and num_spatial_dims is not None:
num_batch_dims = input_shape.ndims - num_spatial_dims - 1
if num_spatial_dims is None:
num_spatial_dims = input_shape.ndims - 2
else:
if input_shape.ndims is not None:
if input_shape.ndims < num_spatial_dims + 2:
raise ValueError(
"Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
"input_shape.ndims == {} and num_spatial_dims == {}"
.format(input_shape.ndims, num_spatial_dims))
else:
if num_batch_dims is None:
num_batch_dims = input_shape.ndims - num_spatial_dims - 1
if num_spatial_dims is None:
raise ValueError(
"input tensor must have rank %d" % (num_spatial_dims + 2))
"Cannot estimate num_spatial_dims since input_shape.ndims is None, "
"filter_shape.ndims is None, and argument num_spatial_dims is also "
"None.")
try:
filter_shape.with_rank(num_spatial_dims + 2)
except ValueError:
if num_batch_dims is None:
num_batch_dims = 1
if num_batch_dims < 1:
raise ValueError(
"filter tensor must have rank %d" % (num_spatial_dims + 2))
"num_batch_dims should be >= 1, but saw {}. num_batch_dims was "
"estimated as `input_shape.ndims - num_spatial_dims - 1` and "
"num_spatial_dims was either provided or estimated as "
"`filter_shape.ndims - 2`. input_shape.ndims: {}, "
"num_spatial_dims: {}, filter_shape.ndims: {}"
.format(num_batch_dims, input_shape.ndims, num_spatial_dims,
filter_shape.ndims))
if data_format is None or not data_format.startswith("NC"):
input_channels_dim = tensor_shape.dimension_at_index(
input_shape, num_spatial_dims + 1)
spatial_dims = range(1, num_spatial_dims + 1)
input_shape, num_spatial_dims + num_batch_dims)
spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
else:
input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
spatial_dims = range(2, num_spatial_dims + 2)
input_channels_dim = tensor_shape.dimension_at_index(
input_shape, num_batch_dims)
spatial_dims = range(
num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
if not input_channels_dim.is_compatible_with(
filter_shape[num_spatial_dims]):
raise ValueError(
"number of input channels does not match corresponding dimension of "
"Number of input channels does not match corresponding dimension of "
"filter, {} != {}".format(input_channels_dim,
filter_shape[num_spatial_dims]))
@@ -1136,6 +1259,8 @@ class Convolution(object):
self.padding = padding
self.name = name
self.dilation_rate = dilation_rate
self.num_batch_dims = num_batch_dims
self.num_spatial_dims = num_spatial_dims
self.conv_op = _WithSpaceToBatch(
input_shape,
dilation_rate=dilation_rate,
@@ -1143,7 +1268,8 @@ class Convolution(object):
build_op=self._build_op,
filter_shape=filter_shape,
spatial_dims=spatial_dims,
data_format=data_format)
data_format=data_format,
num_batch_dims=num_batch_dims)
def _build_op(self, _, padding):
return _NonAtrousConvolution(
@@ -1152,7 +1278,8 @@ class Convolution(object):
padding=padding,
data_format=self.data_format,
strides=self.strides,
name=self.name)
name=self.name,
num_batch_dims=self.num_batch_dims)
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
# TPU convolution supports dilations greater than 1.
@@ -1165,7 +1292,8 @@ class Convolution(object):
data_format=self.data_format,
dilations=self.dilation_rate,
name=self.name,
call_from_convolution=False)
call_from_convolution=False,
num_spatial_dims=self.num_spatial_dims)
else:
return self.conv_op(inp, filter)
@@ -2392,6 +2520,42 @@ def conv2d_transpose_v2(
name=name)
def _conv2d_expanded_batch(
input, # pylint: disable=redefined-builtin
filters,
strides,
padding,
data_format,
dilations,
name):
"""Helper function for `convolution_internal`; handles expanded batches."""
# Try really hard to avoid modifying the legacy name scopes - return early.
shape = getattr(input, "shape", None)
if shape is not None:
ndims = getattr(shape, "ndims", -1)
if ndims == -1: ndims = len(shape)
if ndims in (4, 3, 2, 1, 0, None):
return gen_nn_ops.conv2d(
input,
filter=filters,
strides=strides,
padding=padding,
data_format=data_format,
dilations=dilations,
name=name)
return _squeeze_batch_dims(
input,
functools.partial(
gen_nn_ops.conv2d,
filter=filters,
strides=strides,
padding=padding,
data_format=data_format,
dilations=dilations),
inner_rank=3,
name=name)
@tf_export("nn.atrous_conv2d_transpose")
@dispatch.add_dispatch_support
def atrous_conv2d_transpose(value,
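
_squeeze_batch_dims, called by _conv2d_expanded_batch above, is not part of this
diff. A minimal sketch of what such a helper has to do, under a hypothetical
simplified signature (the real helper may differ):

    from tensorflow.python.ops import array_ops

    def _squeeze_batch_dims_sketch(inp, op, inner_rank, name=None):
      """Collapses outer batch dims, applies `op`, then restores them."""
      shape = array_ops.shape(inp)
      inner_shape = shape[-inner_rank:]   # e.g. [H, W, C] when inner_rank == 3
      batch_shape = shape[:-inner_rank]   # all outer batch dimensions
      # Merge every batch dim into one so `op` sees a rank inner_rank + 1 tensor.
      squeezed = array_ops.reshape(
          inp, array_ops.concat([[-1], inner_shape], axis=0))
      out = op(squeezed)
      # `op` may change spatial/channel sizes; keep its trailing shape.
      out_inner_shape = array_ops.shape(out)[1:]
      return array_ops.reshape(
          out, array_ops.concat([batch_shape, out_inner_shape], axis=0),
          name=name)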