[TF] Add support for more than one outer batch dimension to tf.nn.convolution.

This is part 2/N of adding outer batch dimension support to tf.nn.convXd and
keras.layers.ConvXd. It also adds support for batch_shape.ndims > 1 to
nn_ops.Convolution and other internal libraries, so that we can use this in
keras.layers.ConvXD. For now, calling tf.nn.convolution with a rank-3 or
rank-5 filter (conv1d or conv3d) still raises an error deep in the ops,
because I haven't yet added reshape wrappers for gen_nn_ops.conv{1d,3d};
those will be straightforward to add once this is in. I wanted to make sure
it works for conv2d first. No public signature changes. Rollback of rollback
with fixes.

PiperOrigin-RevId: 312735044
Change-Id: I4b4497a2925a965fa45f1812d7bd25d7a2c087ac
parent 57c5d33f89
commit 7d9d943192

Changes under tensorflow/python:
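The user-visible effect, sketched with shapes borrowed from the tests below (illustrative values; `tf.nn.convolution` is the public API being extended here):

    import numpy as np
    import tensorflow as tf

    # Input with two outer batch dimensions: batch_shape [2, 5],
    # spatial shape [2, 3], and 3 channels.
    x = np.random.rand(2, 5, 2, 3, 3).astype(np.float32)
    # A rank-4 (2-D convolution) filter: 1x1 window, 3 in-channels, 3 out-channels.
    filters = np.random.rand(1, 1, 3, 3).astype(np.float32)

    # With this change, batch_shape.ndims > 1 is accepted for 2-D convolutions.
    y = tf.nn.convolution(x, filters, strides=[1, 1], padding="VALID")
    print(y.shape)  # (2, 5, 2, 3, 3): the outer batch dims [2, 5] are preserved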
@@ -455,6 +455,58 @@ class Conv2DTest(test.TestCase):
         conv1,
         self.evaluate(conv2).reshape(conv1.shape))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass2DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
   @test_util.run_in_graph_and_eager_modes
   def testConv2D2x2Filter2x1Dilation(self):
     self._VerifyDilatedConvValues(
@@ -131,9 +131,9 @@ def _non_atrous_convolution(
   """
   with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
     input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-    input_shape = input.get_shape()
+    input_shape = input.shape
     filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
-    filter_shape = filter.get_shape()
+    filter_shape = filter.shape
     op = _NonAtrousConvolution(
         input_shape,
         filter_shape=filter_shape,
@@ -148,36 +148,51 @@ class _NonAtrousConvolution(object):
   """Helper class for _non_atrous_convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
+  `__call__` are compatible with `input_shape` and `filter_shape` passed to the
   constructor.
 
   Arguments:
-    input_shape: static input shape, i.e. input.get_shape().
-    filter_shape: static filter shape, i.e. filter.get_shape().
+    input_shape: static input shape, i.e. input.shape.
+    filter_shape: static filter shape, i.e. filter.shape.
     padding: see _non_atrous_convolution.
     data_format: see _non_atrous_convolution.
     strides: see _non_atrous_convolution.
     name: see _non_atrous_convolution.
+    num_batch_dims: (Optional.) The number of batch dimensions in the input;
+      if not provided, the default of `1` is used.
   """
 
   def __init__(
       self,
       input_shape,
-      filter_shape,  # pylint: disable=redefined-builtin
+      filter_shape,
       padding,
       data_format=None,
       strides=None,
-      name=None):
-    filter_shape = filter_shape.with_rank(input_shape.ndims)
+      name=None,
+      num_batch_dims=1):
+    # filter shape is always rank num_spatial_dims + 2
+    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
+    if input_shape.ndims is not None:
+      filter_shape = filter_shape.with_rank(
+          input_shape.ndims - num_batch_dims + 1)
     self.padding = padding
     self.name = name
-    input_shape = input_shape.with_rank(filter_shape.ndims)
+    # input shape rank is num_spatial_dims + num_batch_dims + 1
+    # and filter_shape is always rank num_spatial_dims + 2
+    if filter_shape.ndims is not None:
+      input_shape = input_shape.with_rank(
+          filter_shape.ndims + num_batch_dims - 1)
     if input_shape.ndims is None:
-      raise ValueError("Rank of convolution must be known")
-    if input_shape.ndims < 3 or input_shape.ndims > 5:
       raise ValueError(
-          "`input` and `filter` must have rank at least 3 and at most 5")
-    conv_dims = input_shape.ndims - 2
+          "Rank of convolution must be known, but saw input_shape.ndims == {}"
+          .format(input_shape.ndims))
+    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
+      raise ValueError(
+          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
+          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
+          .format(input_shape.ndims, num_batch_dims))
+    conv_dims = input_shape.ndims - num_batch_dims - 1
     if strides is None:
       strides = [1] * conv_dims
     elif len(strides) != conv_dims:
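As a sanity check on the rank bookkeeping in `_NonAtrousConvolution` above, the same arithmetic written out for one concrete case (numbers are illustrative, not taken from the diff):

    # 2-D convolution with two batch dims: rank 5 = 2 (batch) + 2 (spatial) + 1 (channel).
    input_rank, num_batch_dims = 5, 2
    conv_dims = input_rank - num_batch_dims - 1            # num_spatial_dims == 2
    filter_rank = conv_dims + 2                            # 4: [height, width, in_ch, out_ch]
    assert filter_rank == input_rank - num_batch_dims + 1  # the with_rank() bound above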
@@ -520,7 +535,7 @@ def with_space_to_batch(
 
   """
   input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-  input_shape = input.get_shape()
+  input_shape = input.shape
 
   def build_op(num_spatial_dims, padding):
     return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -540,18 +555,19 @@ class _WithSpaceToBatch(object):
   """Helper class for with_space_to_batch.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `spatial_dims` passed to the constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    dilation_rate: see with_space_to_batch
-    padding: see with_space_to_batch
+    input_shape: static shape of input. i.e. input.shape.
+    dilation_rate: see `with_space_to_batch`.
+    padding: see `with_space_to_batch`.
     build_op: Function that maps (num_spatial_dims, paddings) -> (function that
       maps (input, filter) -> output).
-    filter_shape: see with_space_to_batch
-    spatial_dims: see with_space_to_batch
-    data_format: see with_space_to_batch
+    filter_shape: see `with_space_to_batch`.
+    spatial_dims: see `with_space_to_batch`.
+    data_format: see `with_space_to_batch`.
+    num_batch_dims: (Optional.) Number of batch dims in `input_shape`.
   """
 
   def __init__(self,
@@ -561,24 +577,25 @@ class _WithSpaceToBatch(object):
                build_op,
                filter_shape=None,
                spatial_dims=None,
-               data_format=None):
+               data_format=None,
+               num_batch_dims=1):
     """Helper class for _with_space_to_batch."""
     dilation_rate = ops.convert_to_tensor(
         dilation_rate, dtypes.int32, name="dilation_rate")
-    try:
-      rate_shape = dilation_rate.get_shape().with_rank(1)
-    except ValueError:
-      raise ValueError("rate must be rank 1")
+    if dilation_rate.shape.ndims not in (None, 1):
+      raise ValueError(
+          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))
 
-    if not dilation_rate.get_shape().is_fully_defined():
-      raise ValueError("rate must have known shape")
+    if not dilation_rate.shape.is_fully_defined():
+      raise ValueError("rate must have known shape, but saw {}"
+                       .format(dilation_rate.shape))
 
-    num_spatial_dims = rate_shape.dims[0].value
+    num_spatial_dims = dilation_rate.shape.dims[0].value
 
     if data_format is not None and data_format.startswith("NC"):
-      starting_spatial_dim = 2
+      starting_spatial_dim = num_batch_dims + 1
     else:
-      starting_spatial_dim = 1
+      starting_spatial_dim = num_batch_dims
 
     if spatial_dims is None:
       spatial_dims = range(starting_spatial_dim,
@@ -588,7 +605,7 @@ class _WithSpaceToBatch(object):
     if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
       raise ValueError(
           "spatial_dims must be a monotonically increasing sequence of "
-          "positive integers")
+          "positive integers, but saw: {}".format(orig_spatial_dims))
 
     if data_format is not None and data_format.startswith("NC"):
       expected_input_rank = spatial_dims[-1]
@@ -599,14 +616,16 @@ class _WithSpaceToBatch(object):
       input_shape.with_rank_at_least(expected_input_rank)
     except ValueError:
       raise ValueError(
-          "input tensor must have rank %d at least" % (expected_input_rank))
+          "input tensor must have rank at least {}, but saw rank {}"
+          .format(expected_input_rank, input_shape.ndims))
 
     const_rate = tensor_util.constant_value(dilation_rate)
     rate_or_const_rate = dilation_rate
     if const_rate is not None:
       rate_or_const_rate = const_rate
       if np.any(const_rate < 1):
-        raise ValueError("dilation_rate must be positive")
+        raise ValueError("dilation_rate must be positive, but saw: {}"
+                         .format(const_rate))
       if np.all(const_rate == 1):
         self.call = build_op(num_spatial_dims, padding)
         return
@@ -672,6 +691,7 @@ class _WithSpaceToBatch(object):
       filter_shape = array_ops.shape(filter)
       base_paddings = _with_space_to_batch_base_paddings(
           filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
+
     paddings, crops = array_ops.required_space_to_batch_paddings(
         input_shape=input_spatial_shape,
         base_paddings=base_paddings,
@@ -994,31 +1014,84 @@ def convolution_internal(
     data_format=None,
     dilations=None,
     name=None,
-    call_from_convolution=True):
-  """Internal function which performs rank agnostic convolution."""
-  if isinstance(input.shape, tensor_shape.TensorShape) and \
-      input.shape.rank is not None:
-    n = len(input.shape) - 2
-  elif not isinstance(input.shape, tensor_shape.TensorShape) and \
-      input.shape is not None:
-    n = len(input.shape) - 2
-  elif isinstance(filters.shape, tensor_shape.TensorShape) and \
-      filters.shape.rank is not None:
-    n = len(filters.shape) - 2
-  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
-      filters.shape is not None:
-    n = len(filters.shape) - 2
-  else:
-    raise ValueError("rank of input or filter must be known")
+    call_from_convolution=True,
+    num_spatial_dims=None):
+  """Internal function which performs rank agnostic convolution.
+
+  Args:
+    input: See `convolution`.
+    filters: See `convolution`.
+    strides: See `convolution`.
+    padding: See `convolution`.
+    data_format: See `convolution`.
+    dilations: See `convolution`.
+    name: See `convolution`.
+    call_from_convolution: See `convolution`.
+    num_spatial_dims: (Optional.) An integer describing the rank of the
+      spatial dimensions. For `1-D`, `2-D` and `3-D` convolutions, the value
+      of `num_spatial_dims` is `1`, `2`, and `3`, respectively. This argument
+      is only required to disambiguate the rank of `batch_shape` when
+      `filter_shape.ndims is None` and `len(batch_shape) > 1`. For backwards
+      compatibility, if `num_spatial_dims is None` and `filter_shape.ndims is
+      None`, then `len(batch_shape)` is assumed to be `1` (i.e., the input is
+      expected to be `[batch_size, num_channels] + input_spatial_shape`
+      or `[batch_size] + input_spatial_shape + [num_channels]`).
+
+  Returns:
+    A tensor of shape and dtype matching that of `input`.
+
+  Raises:
+    ValueError: If input and filter both have unknown shapes, or if
+      `num_spatial_dims` is provided and incompatible with the value
+      estimated from `filters.shape`.
+  """
+  n = None
+  if getattr(filters, 'shape', None) is None:
+    with ops.name_scope(name, 'convolution_internal', [filters, input]):
+      filters = ops.convert_to_tensor(filters, name='filters')
+  if (isinstance(filters.shape, tensor_shape.TensorShape)
+      and filters.shape.rank is not None):
+    n = len(filters.shape) - 2
+  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
+        and filters.shape is not None):
+    n = len(filters.shape) - 2
+
+  if (isinstance(input.shape, tensor_shape.TensorShape)
+      and input.shape.rank is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  elif (not isinstance(input.shape, tensor_shape.TensorShape)
+        and input.shape is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  else:
+    num_batch_dims = 1  # Default behavior if it cannot be estimated.
+
+  if n is None:
+    raise ValueError("rank of input or filter must be known")
+
+  if num_spatial_dims is not None and n != num_spatial_dims:
+    raise ValueError(
+        "inconsistent estimate of spatial dims ({}) vs. actual passed "
+        "num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
+        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
 
   if not 1 <= n <= 3:
     raise ValueError(
-        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
+        "of 1, 2 or 3 but saw {}. num_batch_dims: {}."
+        .format(n, num_batch_dims))
 
   if data_format is None:
-    channel_index = n + 1
+    channel_index = num_batch_dims + n
   else:
-    channel_index = 1 if data_format.startswith("NC") else n + 1
+    channel_index = (
+        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
 
   strides = _get_sequence(strides, n, channel_index, "strides")
   dilations = _get_sequence(dilations, n, channel_index, "dilations")
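To make the `num_spatial_dims` disambiguation above concrete, one worked case (assumed shapes, mirroring the tests earlier in this change):

    # input shape [2, 5, 2, 3, 3]: batch_shape [2, 5], spatial [2, 3], 3 channels.
    input_ndims = 5
    n = 2                                 # num_spatial_dims, known from a rank-4 filter
    num_batch_dims = input_ndims - n - 1  # 2, i.e. batch_shape is [2, 5]
    channel_index = num_batch_dims + n    # 4 for channels-last (no "NC" data_format)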
@@ -1031,7 +1104,7 @@ def convolution_internal(
     scope = "convolution"
 
   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
 
     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1061,7 +1134,8 @@ def convolution_internal(
           strides=strides,
           dilation_rate=dilations,
           name=name,
-          data_format=data_format)
+          data_format=data_format,
+          num_spatial_dims=n)
       return op(input, filters)
 
 
@@ -1069,17 +1143,34 @@ class Convolution(object):
   """Helper class for convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `num_spatial_dims` passed to the constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    filter_shape: static shape of the filter. i.e. filter.get_shape().
-    padding: see convolution.
+    input_shape: static shape of input. i.e. input.shape. It is
+      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
+      does not start with `NC`, or
+      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
+      starts with `NC`.
+    filter_shape: static shape of the filter. i.e. filter.shape.
+    padding: The padding algorithm, must be "SAME" or "VALID".
     strides: see convolution.
     dilation_rate: see convolution.
     name: see convolution.
-    data_format: see convolution.
+    data_format: A string or `None`. Specifies whether the channel dimension
+      of the `input` and output is the last dimension (if `data_format` is
+      `None` or does not start with `NC`), or the first post-batch dimension
+      (i.e. if `data_format` starts with `NC`).
+    num_spatial_dims: (Usually optional.) Python integer, the rank of the
+      spatial and channel dimensions. For `1-D`, `2-D` and `3-D` convolutions,
+      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
+      This argument is only required to disambiguate the rank of `batch_shape`
+      when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
+      backwards compatibility, if `num_spatial_dims is None` and
+      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
+      `1` (i.e., the input is expected to be
+      `[batch_size, num_channels] + input_spatial_shape`
+      or `[batch_size] + input_spatial_shape + [num_channels]`).
   """
 
   def __init__(self,
@@ -1089,40 +1180,72 @@ class Convolution(object):
                strides=None,
                dilation_rate=None,
                name=None,
-               data_format=None):
+               data_format=None,
+               num_spatial_dims=None):
     """Helper function for convolution."""
-    num_total_dims = filter_shape.ndims
-    if num_total_dims is None:
-      num_total_dims = input_shape.ndims
-    if num_total_dims is None:
-      raise ValueError("rank of input or filter must be known")
+    num_batch_dims = None
+    filter_shape = tensor_shape.as_shape(filter_shape)
+    input_shape = tensor_shape.as_shape(input_shape)
 
-    num_spatial_dims = num_total_dims - 2
+    if filter_shape.ndims is not None:
+      if (num_spatial_dims is not None and
+          filter_shape.ndims != num_spatial_dims + 2):
+        raise ValueError(
+            "Expected filter_shape.ndims == num_spatial_dims + 2, "
+            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
+            .format(filter_shape.ndims, num_spatial_dims))
+      else:
+        num_spatial_dims = filter_shape.ndims - 2
 
-    try:
-      input_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
+    if input_shape.ndims is not None and num_spatial_dims is not None:
+      num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
+      num_spatial_dims = input_shape.ndims - 2
+    else:
+      if input_shape.ndims is not None:
+        if input_shape.ndims < num_spatial_dims + 2:
+          raise ValueError(
+              "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
+              "input_shape.ndims == {} and num_spatial_dims == {}"
+              .format(input_shape.ndims, num_spatial_dims))
+        else:
+          if num_batch_dims is None:
+            num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
       raise ValueError(
-          "input tensor must have rank %d" % (num_spatial_dims + 2))
+          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
+          "filter_shape.ndims is None, and argument num_spatial_dims is also "
+          "None.")
 
-    try:
-      filter_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
+    if num_batch_dims is None:
+      num_batch_dims = 1
+
+    if num_batch_dims < 1:
       raise ValueError(
-          "filter tensor must have rank %d" % (num_spatial_dims + 2))
+          "num_batch_dims should be >= 1, but saw {}. num_batch_dims was "
+          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
+          "num_spatial_dims was either provided or estimated as "
+          "`filter_shape.ndims - 2`. input_shape.ndims: {}, "
+          "num_spatial_dims: {}, filter_shape.ndims: {}"
+          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
                  filter_shape.ndims))
 
     if data_format is None or not data_format.startswith("NC"):
       input_channels_dim = tensor_shape.dimension_at_index(
-          input_shape, num_spatial_dims + 1)
-      spatial_dims = range(1, num_spatial_dims + 1)
+          input_shape, num_spatial_dims + num_batch_dims)
+      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
     else:
-      input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
-      spatial_dims = range(2, num_spatial_dims + 2)
+      input_channels_dim = tensor_shape.dimension_at_index(
+          input_shape, num_batch_dims)
+      spatial_dims = range(
+          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
 
     if not input_channels_dim.is_compatible_with(
         filter_shape[num_spatial_dims]):
       raise ValueError(
-          "number of input channels does not match corresponding dimension of "
+          "Number of input channels does not match corresponding dimension of "
           "filter, {} != {}".format(input_channels_dim,
                                     filter_shape[num_spatial_dims]))
@@ -1136,6 +1259,8 @@ class Convolution(object):
     self.padding = padding
     self.name = name
     self.dilation_rate = dilation_rate
+    self.num_batch_dims = num_batch_dims
+    self.num_spatial_dims = num_spatial_dims
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1143,7 +1268,8 @@ class Convolution(object):
         build_op=self._build_op,
         filter_shape=filter_shape,
         spatial_dims=spatial_dims,
-        data_format=data_format)
+        data_format=data_format,
+        num_batch_dims=num_batch_dims)
 
   def _build_op(self, _, padding):
     return _NonAtrousConvolution(
@@ -1152,7 +1278,8 @@ class Convolution(object):
         padding=padding,
         data_format=self.data_format,
         strides=self.strides,
-        name=self.name)
+        name=self.name,
+        num_batch_dims=self.num_batch_dims)
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
     # TPU convolution supports dilations greater than 1.
@@ -1165,7 +1292,8 @@ class Convolution(object):
           data_format=self.data_format,
           dilations=self.dilation_rate,
           name=self.name,
-          call_from_convolution=False)
+          call_from_convolution=False,
+          num_spatial_dims=self.num_spatial_dims)
     else:
       return self.conv_op(inp, filter)
 
@@ -2392,6 +2520,42 @@ def conv2d_transpose_v2(
       name=name)
 
 
+def _conv2d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filters,
+    strides,
+    padding,
+    data_format,
+    dilations,
+    name):
+  """Helper function for `convolution_internal`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1:
+      ndims = len(shape)
+    if ndims in (4, 3, 2, 1, 0, None):
+      return gen_nn_ops.conv2d(
+          input,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+  return _squeeze_batch_dims(
+      input,
+      functools.partial(
+          gen_nn_ops.conv2d,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations),
+      inner_rank=3,
+      name=name)
+
+
 @tf_export("nn.atrous_conv2d_transpose")
 @dispatch.add_dispatch_support
 def atrous_conv2d_transpose(value,
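`_squeeze_batch_dims` itself is not shown in this diff; below is a minimal sketch of the reshape-wrapper technique it is assumed to implement (collapse the outer batch dims into one, run the rank-4 op, then restore them). The helper name `squeeze_batch_dims_sketch` is hypothetical; only the public TF ops used inside it are real:

    import functools
    import numpy as np
    import tensorflow as tf

    def squeeze_batch_dims_sketch(inp, op, inner_rank):
      """Illustration only: merge leading batch dims, apply op, restore them."""
      shape = tf.shape(inp)
      batch_shape = shape[:-inner_rank]   # all leading batch dims
      inner_shape = shape[-inner_rank:]   # spatial + channel dims
      # Merge the batch dims into one, e.g. [2, 5, 2, 3, 3] -> [10, 2, 3, 3].
      flat = tf.reshape(inp, tf.concat([[-1], inner_shape], axis=0))
      out = op(flat)
      # Restore the original batch dims on the result.
      out_inner = tf.shape(out)[1:]
      return tf.reshape(out, tf.concat([batch_shape, out_inner], axis=0))

    # Usage mirroring _conv2d_expanded_batch above:
    x = np.random.rand(2, 5, 2, 3, 3).astype(np.float32)
    f = np.random.rand(1, 1, 3, 3).astype(np.float32)
    y = squeeze_batch_dims_sketch(
        x,
        functools.partial(tf.nn.conv2d, filters=f, strides=[1, 1, 1, 1],
                          padding="VALID"),
        inner_rank=3)
    print(y.shape)  # (2, 5, 2, 3, 3)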