[TF] Add support for more than one outer batch dimension to tf.nn.convolution.

This is part 2/N of adding outer batch dimension support to tf.nn.convXd and keras.layers.ConvXd.

Also added support for batch_shape.ndims > 1 to nn_ops.Convolution and other internal
libraries, so that we can use this in keras.layers.ConvXD.
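
A minimal sketch of what this enables internally (shapes taken from the new test below; nn_ops and tensor_shape are the TF-internal modules used in the diff):

  # Sketch only: a [2, 5, 2, 3, 3] input is interpreted as batch_shape=[2, 5]
  # plus a 2-D convolution over [2, 3, 3] with a [1, 1, 3, 3] filter.
  convolver = nn_ops.Convolution(
      input_shape=tensor_shape.TensorShape([2, 5, 2, 3, 3]),
      filter_shape=tensor_shape.TensorShape([1, 1, 3, 3]),
      strides=[1, 1],
      padding="VALID")
  assert convolver.num_batch_dims == 2  # two outer batch dimensions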

For now, using tf.nn.convolution with a rank-3 or rank-5 filter (conv1d or conv3d) and an
expanded batch still raises an error deep in the ops, because I haven't yet added reshape
wrappers for gen_nn_ops.conv{1d,3d}; those will be easy to add once this is in.  I wanted
to make sure it works for conv2d first.

No public signature changes.
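
For reference, a minimal sketch of the public call pattern this enables (conv2d only for
now; shapes mirror the new test):

  import tensorflow as tf

  x = tf.reshape(tf.range(2 * 5 * 2 * 3 * 3, dtype=tf.float32), [2, 5, 2, 3, 3])
  filters = tf.ones([1, 1, 3, 3])  # rank-4 filter => 2 spatial dims
  # The two leading dimensions [2, 5] are treated as batch dimensions.
  y = tf.nn.convolution(x, filters, strides=[1, 1], padding="VALID")
  print(y.shape)  # (2, 5, 2, 3, 3)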

PiperOrigin-RevId: 312697999
Change-Id: I01107967101f28b9906074b3c88664a3a09e8c4b
Eugene Brevdo 2020-05-21 10:44:03 -07:00 committed by TensorFlower Gardener
parent c2534e2336
commit 37b60af536
2 changed files with 296 additions and 81 deletions


@@ -455,6 +455,58 @@ class Conv2DTest(test.TestCase):
         conv1,
         self.evaluate(conv2).reshape(conv1.shape))

+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass2DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
   @test_util.run_in_graph_and_eager_modes
   def testConv2D2x2Filter2x1Dilation(self):
     self._VerifyDilatedConvValues(


@@ -131,9 +131,9 @@ def _non_atrous_convolution(
   """
   with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
     input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-    input_shape = input.get_shape()
+    input_shape = input.shape
     filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
-    filter_shape = filter.get_shape()
+    filter_shape = filter.shape
     op = _NonAtrousConvolution(
         input_shape,
         filter_shape=filter_shape,
@@ -148,36 +148,51 @@ class _NonAtrousConvolution(object):
   """Helper class for _non_atrous_convolution.

   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
+  `__call__` are compatible with `input_shape` and filter_shape passed to the
   constructor.

   Arguments:
-    input_shape: static input shape, i.e. input.get_shape().
-    filter_shape: static filter shape, i.e. filter.get_shape().
+    input_shape: static input shape, i.e. input.shape.
+    filter_shape: static filter shape, i.e. filter.shape.
     padding: see _non_atrous_convolution.
     data_format: see _non_atrous_convolution.
     strides: see _non_atrous_convolution.
     name: see _non_atrous_convolution.
+    num_batch_dims: (Optional.) The number of batch dimensions in the input;
+      if not provided, the default of `1` is used.
   """

   def __init__(
       self,
       input_shape,
-      filter_shape,  # pylint: disable=redefined-builtin
+      filter_shape,
       padding,
       data_format=None,
       strides=None,
-      name=None):
-    filter_shape = filter_shape.with_rank(input_shape.ndims)
+      name=None,
+      num_batch_dims=1):
+    # filter shape is always rank num_spatial_dims + 2
+    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
+    if input_shape.ndims is not None:
+      filter_shape = filter_shape.with_rank(
+          input_shape.ndims - num_batch_dims + 1)
     self.padding = padding
     self.name = name
-    input_shape = input_shape.with_rank(filter_shape.ndims)
+    # input shape is == num_spatial_dims + num_batch_dims + 1
+    # and filter_shape is always rank num_spatial_dims + 2
+    if filter_shape.ndims is not None:
+      input_shape = input_shape.with_rank(
+          filter_shape.ndims + num_batch_dims - 1)
     if input_shape.ndims is None:
-      raise ValueError("Rank of convolution must be known")
-    if input_shape.ndims < 3 or input_shape.ndims > 5:
-      raise ValueError(
-          "`input` and `filter` must have rank at least 3 and at most 5")
-    conv_dims = input_shape.ndims - 2
+      raise ValueError(
+          "Rank of convolution must be known, but saw input_shape.ndims == {}"
+          .format(input_shape.ndims))
+    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
+      raise ValueError(
+          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
+          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
+          .format(input_shape.ndims, num_batch_dims))
+    conv_dims = input_shape.ndims - num_batch_dims - 1
     if strides is None:
       strides = [1] * conv_dims
     elif len(strides) != conv_dims:
@@ -520,7 +535,7 @@ def with_space_to_batch(
   """
   input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-  input_shape = input.get_shape()
+  input_shape = input.shape

   def build_op(num_spatial_dims, padding):
     return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -540,18 +555,19 @@ class _WithSpaceToBatch(object):
   """Helper class for with_space_to_batch.

   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `spatial_dims` passed to the constructor.

   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    dilation_rate: see with_space_to_batch
-    padding: see with_space_to_batch
+    input_shape: static shape of input. i.e. input.shape.
+    dilation_rate: see `with_space_to_batch`.
+    padding: see `with_space_to_batch`.
     build_op: Function that maps (num_spatial_dims, paddings) -> (function that
       maps (input, filter) -> output).
-    filter_shape: see with_space_to_batch
-    spatial_dims: see with_space_to_batch
-    data_format: see with_space_to_batch
+    filter_shape: see `with_space_to_batch`.
+    spatial_dims: `see with_space_to_batch`.
+    data_format: see `with_space_to_batch`.
+    num_batch_dims: (Optional). Number of batch dims in `input_shape`.
   """

   def __init__(self,
@@ -561,24 +577,25 @@ class _WithSpaceToBatch(object):
                build_op,
                filter_shape=None,
                spatial_dims=None,
-               data_format=None):
+               data_format=None,
+               num_batch_dims=1):
     """Helper class for _with_space_to_batch."""
     dilation_rate = ops.convert_to_tensor(
         dilation_rate, dtypes.int32, name="dilation_rate")
-    try:
-      rate_shape = dilation_rate.get_shape().with_rank(1)
-    except ValueError:
-      raise ValueError("rate must be rank 1")
+    if dilation_rate.shape.ndims not in (None, 1):
+      raise ValueError(
+          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))

-    if not dilation_rate.get_shape().is_fully_defined():
-      raise ValueError("rate must have known shape")
+    if not dilation_rate.shape.is_fully_defined():
+      raise ValueError("rate must have known shape, but saw {}"
+                       .format(dilation_rate.shape))

-    num_spatial_dims = rate_shape.dims[0].value
+    num_spatial_dims = dilation_rate.shape.dims[0].value

     if data_format is not None and data_format.startswith("NC"):
-      starting_spatial_dim = 2
+      starting_spatial_dim = num_batch_dims + 1
     else:
-      starting_spatial_dim = 1
+      starting_spatial_dim = num_batch_dims

     if spatial_dims is None:
       spatial_dims = range(starting_spatial_dim,
@@ -588,7 +605,7 @@ class _WithSpaceToBatch(object):
     if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
       raise ValueError(
           "spatial_dims must be a monotonically increasing sequence of "
-          "positive integers")
+          "positive integers, but saw: {}".format(orig_spatial_dims))

     if data_format is not None and data_format.startswith("NC"):
       expected_input_rank = spatial_dims[-1]
@@ -599,14 +616,16 @@ class _WithSpaceToBatch(object):
       input_shape.with_rank_at_least(expected_input_rank)
     except ValueError:
       raise ValueError(
-          "input tensor must have rank %d at least" % (expected_input_rank))
+          "input tensor must have rank at least {}, but saw rank {}"
+          .format(expected_input_rank, input_shape.ndims))

     const_rate = tensor_util.constant_value(dilation_rate)
     rate_or_const_rate = dilation_rate
     if const_rate is not None:
       rate_or_const_rate = const_rate
       if np.any(const_rate < 1):
-        raise ValueError("dilation_rate must be positive")
+        raise ValueError("dilation_rate must be positive, but saw: {}"
+                         .format(const_rate))
       if np.all(const_rate == 1):
         self.call = build_op(num_spatial_dims, padding)
         return
@@ -672,6 +691,7 @@ class _WithSpaceToBatch(object):
       filter_shape = array_ops.shape(filter)
       base_paddings = _with_space_to_batch_base_paddings(
           filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
+
     paddings, crops = array_ops.required_space_to_batch_paddings(
         input_shape=input_spatial_shape,
         base_paddings=base_paddings,
@@ -994,31 +1014,83 @@ def convolution_internal(
     data_format=None,
     dilations=None,
     name=None,
-    call_from_convolution=True):
-  """Internal function which performs rank agnostic convolution."""
-  if isinstance(input.shape, tensor_shape.TensorShape) and \
-      input.shape.rank is not None:
-    n = len(input.shape) - 2
-  elif not isinstance(input.shape, tensor_shape.TensorShape) and \
-      input.shape is not None:
-    n = len(input.shape) - 2
-  elif isinstance(filters.shape, tensor_shape.TensorShape) and \
-      filters.shape.rank is not None:
-    n = len(filters.shape) - 2
-  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
-      filters.shape is not None:
-    n = len(filters.shape) - 2
-  else:
-    raise ValueError("rank of input or filter must be known")
+    call_from_convolution=True,
+    num_spatial_dims=None):
+  """Internal function which performs rank agnostic convolution.
+
+  Args:
+    input: See `convolution`.
+    filters: See `convolution`.
+    strides: See `convolution`.
+    padding: See `convolution`.
+    data_format: See `convolution`.
+    dilations: See `convolution`.
+    name: See `convolution`.
+    call_from_convolution: See `convolution`.
+    num_spatial_dims: (Optional.). It is a integer describing the
+      rank of the spatial dimensions. For `1-D`, `2-D` and `3-D` convolutions,
+      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
+      This argument is only required to disambiguate the rank of `batch_shape`
+      when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
+      backwards compatibility, if `num_spatial_dims is None` and
+      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
+      `1` (i.e., the input is expected to be
+      `[batch_size, num_channels] + input_spatial_shape`
+      or `[batch_size] + input_spatial_shape + [num_channels]`.
+
+  Returns:
+    A tensor of shape and dtype matching that of `input`.
+
+  Raises:
+    ValueError: If input and filter both have unknown shapes, or if
+      `num_spatial_dims` is provided and incompatible with the value
+      estimated from `filters.shape`.
+  """
+  n = None
+  if isinstance(filters, (list, tuple)):
+    filters = np.asarray(filters)
+
+  if (isinstance(filters.shape, tensor_shape.TensorShape)
+      and filters.shape.rank is not None):
+    n = len(filters.shape) - 2
+  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
+        and filters.shape is not None):
+    n = len(filters.shape) - 2
+
+  if (isinstance(input.shape, tensor_shape.TensorShape)
+      and input.shape.rank is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  elif (not isinstance(input.shape, tensor_shape.TensorShape)
+        and input.shape is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  else:
+    num_batch_dims = 1  # Default behavior if it cannot be estimated.
+
+  if n is None:
+    raise ValueError("rank of input or filter must be known")
+
+  if num_spatial_dims is not None and n != num_spatial_dims:
+    raise ValueError(
+        "inconsistent estimate of spatial dims ({}) vs. actual passed "
+        "num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
+        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))

   if not 1 <= n <= 3:
     raise ValueError(
-        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
+        "of 1, 2 or 3 but saw {}. num_batch_dims: {}."
+        .format(n, num_batch_dims))

   if data_format is None:
-    channel_index = n + 1
+    channel_index = num_batch_dims + n
   else:
-    channel_index = 1 if data_format.startswith("NC") else n + 1
+    channel_index = (
+        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)

   strides = _get_sequence(strides, n, channel_index, "strides")
   dilations = _get_sequence(dilations, n, channel_index, "dilations")
@@ -1031,7 +1103,7 @@ def convolution_internal(
     scope = "convolution"

   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}

     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1061,7 +1133,8 @@ def convolution_internal(
           strides=strides,
           dilation_rate=dilations,
           name=name,
-          data_format=data_format)
+          data_format=data_format,
+          num_spatial_dims=n)
       return op(input, filters)
@@ -1069,17 +1142,34 @@ class Convolution(object):
   """Helper class for convolution.

   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `num_spatial_dims` passed to the constructor.

   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    filter_shape: static shape of the filter. i.e. filter.get_shape().
-    padding: see convolution.
+    input_shape: static shape of input. i.e. input.shape. Its length is
+      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
+      does not start with `NC`, or
+      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
+      starts with `NC`.
+    filter_shape: static shape of the filter. i.e. filter.shape.
+    padding: The padding algorithm, must be "SAME" or "VALID".
     strides: see convolution.
     dilation_rate: see convolution.
     name: see convolution.
-    data_format: see convolution.
+    data_format: A string or `None`. Specifies whether the channel dimension of
+      the `input` and output is the last dimension (if `data_format` is `None`
+      or does not start with `NC`), or the first post-batch dimension (i.e. if
+      `data_format` starts with `NC`).
+    num_spatial_dims: (Usually optional.) Python integer, the rank of the
+      spatial and channel dimensions. For `1-D`, `2-D` and `3-D` convolutions,
+      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
+      This argument is only required to disambiguate the rank of `batch_shape`
+      when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
+      backwards compatibility, if `num_spatial_dims is None` and
+      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
+      `1` (i.e., the input is expected to be
+      `[batch_size, num_channels] + input_spatial_shape`
+      or `[batch_size] + input_spatial_shape + [num_channels]`.
   """

   def __init__(self,
@@ -1089,40 +1179,72 @@ class Convolution(object):
                strides=None,
                dilation_rate=None,
                name=None,
-               data_format=None):
+               data_format=None,
+               num_spatial_dims=None):
     """Helper function for convolution."""
-    num_total_dims = filter_shape.ndims
-    if num_total_dims is None:
-      num_total_dims = input_shape.ndims
-    if num_total_dims is None:
-      raise ValueError("rank of input or filter must be known")
-
-    num_spatial_dims = num_total_dims - 2
-
-    try:
-      input_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
-      raise ValueError(
-          "input tensor must have rank %d" % (num_spatial_dims + 2))
-
-    try:
-      filter_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
-      raise ValueError(
-          "filter tensor must have rank %d" % (num_spatial_dims + 2))
+    num_batch_dims = None
+    filter_shape = tensor_shape.as_shape(filter_shape)
+    input_shape = tensor_shape.as_shape(input_shape)
+
+    if filter_shape.ndims is not None:
+      if (num_spatial_dims is not None and
+          filter_shape.ndims != num_spatial_dims + 2):
+        raise ValueError(
+            "Expected filter_shape.ndims == num_spatial_dims + 2, "
+            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
+            .format(filter_shape.ndims, num_spatial_dims))
+      else:
+        num_spatial_dims = filter_shape.ndims - 2
+
+    if input_shape.ndims is not None and num_spatial_dims is not None:
+      num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
+      num_spatial_dims = input_shape.ndims - 2
+    else:
+      if input_shape.ndims is not None:
+        if input_shape.ndims < num_spatial_dims + 2:
+          raise ValueError(
+              "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
+              "input_shape.ndims == {} and num_spatial_dims == {}"
+              .format(input_shape.ndims, num_spatial_dims))
+        else:
+          if num_batch_dims is None:
+            num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
+      raise ValueError(
+          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
+          "filter_shape.ndims is None, and argument num_spatial_dims is also "
+          "None.")
+
+    if num_batch_dims is None:
+      num_batch_dims = 1
+
+    if num_batch_dims < 1:
+      raise ValueError(
+          "num_batch_dims should be >= 1, but saw {}. num_batch_dims was "
+          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
+          "num_spatial_dims was either provided or estimated as "
+          "`filter_shape.ndims - 2`. input_shape.ndims: {}, "
+          "num_spatial_dims: {}, filter_shape.ndims: {}"
+          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
+                  filter_shape.ndims))

     if data_format is None or not data_format.startswith("NC"):
       input_channels_dim = tensor_shape.dimension_at_index(
-          input_shape, num_spatial_dims + 1)
-      spatial_dims = range(1, num_spatial_dims + 1)
+          input_shape, num_spatial_dims + num_batch_dims)
+      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
     else:
-      input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
-      spatial_dims = range(2, num_spatial_dims + 2)
+      input_channels_dim = tensor_shape.dimension_at_index(
+          input_shape, num_batch_dims)
+      spatial_dims = range(
+          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)

     if not input_channels_dim.is_compatible_with(
         filter_shape[num_spatial_dims]):
       raise ValueError(
-          "number of input channels does not match corresponding dimension of "
+          "Number of input channels does not match corresponding dimension of "
           "filter, {} != {}".format(input_channels_dim,
                                     filter_shape[num_spatial_dims]))
@@ -1136,6 +1258,8 @@ class Convolution(object):
     self.padding = padding
     self.name = name
     self.dilation_rate = dilation_rate
+    self.num_batch_dims = num_batch_dims
+    self.num_spatial_dims = num_spatial_dims
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1143,7 +1267,8 @@ class Convolution(object):
         build_op=self._build_op,
         filter_shape=filter_shape,
         spatial_dims=spatial_dims,
-        data_format=data_format)
+        data_format=data_format,
+        num_batch_dims=num_batch_dims)

   def _build_op(self, _, padding):
     return _NonAtrousConvolution(
@@ -1152,7 +1277,8 @@ class Convolution(object):
         padding=padding,
         data_format=self.data_format,
         strides=self.strides,
-        name=self.name)
+        name=self.name,
+        num_batch_dims=self.num_batch_dims)

   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
     # TPU convolution supports dilations greater than 1.
@@ -1165,7 +1291,8 @@ class Convolution(object):
           data_format=self.data_format,
           dilations=self.dilation_rate,
           name=self.name,
-          call_from_convolution=False)
+          call_from_convolution=False,
+          num_spatial_dims=self.num_spatial_dims)
     else:
       return self.conv_op(inp, filter)
@@ -2392,6 +2519,42 @@ def conv2d_transpose_v2(
       name=name)


+def _conv2d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filters,
+    strides,
+    padding,
+    data_format,
+    dilations,
+    name):
+  """Helper function for `convolution_internal`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1: ndims = len(shape)
+    if ndims in (4, 3, 2, 1, 0, None):
+      return gen_nn_ops.conv2d(
+          input,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+  return _squeeze_batch_dims(
+      input,
+      functools.partial(
+          gen_nn_ops.conv2d,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations),
+      inner_rank=3,
+      name=name)
+
+
 @tf_export("nn.atrous_conv2d_transpose")
 @dispatch.add_dispatch_support
 def atrous_conv2d_transpose(value,