Internal change
PiperOrigin-RevId: 312712086 Change-Id: Iba2311e8ac40ebe73765f273ef48f5550d76fcfc
This commit is contained in:
parent
1d8bc7222d
commit
d3cd2a76cc
@ -455,58 +455,6 @@ class Conv2DTest(test.TestCase):
|
|||||||
conv1,
|
conv1,
|
||||||
self.evaluate(conv2).reshape(conv1.shape))
|
self.evaluate(conv2).reshape(conv1.shape))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testConvolutionClass2DExpandedBatch(self):
|
|
||||||
tensor_in_sizes_batch = [10, 2, 3, 3]
|
|
||||||
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
|
|
||||||
filter_in_sizes = [1, 1, 3, 3]
|
|
||||||
filter_in = self._CreateNumpyTensor(filter_in_sizes)
|
|
||||||
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
|
|
||||||
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
|
|
||||||
convolver1 = nn_ops.Convolution(
|
|
||||||
input_shape=x1.shape,
|
|
||||||
filter_shape=filter_in.shape,
|
|
||||||
strides=[1, 1],
|
|
||||||
padding="VALID")
|
|
||||||
self.assertEqual(convolver1.num_batch_dims, 1)
|
|
||||||
convolver2 = nn_ops.Convolution(
|
|
||||||
input_shape=x2.shape,
|
|
||||||
filter_shape=filter_in.shape,
|
|
||||||
strides=[1, 1],
|
|
||||||
padding="VALID")
|
|
||||||
self.assertEqual(convolver2.num_batch_dims, 2)
|
|
||||||
conv1 = convolver1(x1, filter_in)
|
|
||||||
conv2 = convolver2(x2, filter_in)
|
|
||||||
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
|
|
||||||
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
|
|
||||||
self.assertAllEqual(
|
|
||||||
conv1,
|
|
||||||
self.evaluate(conv2).reshape(conv1.shape))
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
|
|
||||||
tensor_in_sizes_batch = [10, 2, 3, 3]
|
|
||||||
tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
|
|
||||||
filter_in_sizes = [1, 1, 3, 3]
|
|
||||||
filter_in = self._CreateNumpyTensor(filter_in_sizes)
|
|
||||||
x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
|
|
||||||
x2 = x1.reshape(tensor_in_sizes_expanded_batch)
|
|
||||||
conv1 = nn_ops.convolution(
|
|
||||||
x1,
|
|
||||||
filter_in,
|
|
||||||
strides=[1, 1],
|
|
||||||
padding="VALID")
|
|
||||||
conv2 = nn_ops.convolution(
|
|
||||||
x2,
|
|
||||||
filter_in,
|
|
||||||
strides=[1, 1],
|
|
||||||
padding="VALID")
|
|
||||||
self.assertEqual(conv1.shape, tensor_in_sizes_batch)
|
|
||||||
self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
|
|
||||||
self.assertAllEqual(
|
|
||||||
conv1,
|
|
||||||
self.evaluate(conv2).reshape(conv1.shape))
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testConv2D2x2Filter2x1Dilation(self):
|
def testConv2D2x2Filter2x1Dilation(self):
|
||||||
self._VerifyDilatedConvValues(
|
self._VerifyDilatedConvValues(
|
||||||
|
@ -131,9 +131,9 @@ def _non_atrous_convolution(
|
|||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
|
with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
|
||||||
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
|
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
|
||||||
input_shape = input.shape
|
input_shape = input.get_shape()
|
||||||
filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin
|
filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin
|
||||||
filter_shape = filter.shape
|
filter_shape = filter.get_shape()
|
||||||
op = _NonAtrousConvolution(
|
op = _NonAtrousConvolution(
|
||||||
input_shape,
|
input_shape,
|
||||||
filter_shape=filter_shape,
|
filter_shape=filter_shape,
|
||||||
@ -148,51 +148,36 @@ class _NonAtrousConvolution(object):
|
|||||||
"""Helper class for _non_atrous_convolution.
|
"""Helper class for _non_atrous_convolution.
|
||||||
|
|
||||||
Note that this class assumes that shapes of input and filter passed to
|
Note that this class assumes that shapes of input and filter passed to
|
||||||
`__call__` are compatible with `input_shape` and filter_shape passed to the
|
__call__ are compatible with input_shape and filter_shape passed to the
|
||||||
constructor.
|
constructor.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
input_shape: static input shape, i.e. input.shape.
|
input_shape: static input shape, i.e. input.get_shape().
|
||||||
filter_shape: static filter shape, i.e. filter.shape.
|
filter_shape: static filter shape, i.e. filter.get_shape().
|
||||||
padding: see _non_atrous_convolution.
|
padding: see _non_atrous_convolution.
|
||||||
data_format: see _non_atrous_convolution.
|
data_format: see _non_atrous_convolution.
|
||||||
strides: see _non_atrous_convolution.
|
strides: see _non_atrous_convolution.
|
||||||
name: see _non_atrous_convolution.
|
name: see _non_atrous_convolution.
|
||||||
num_batch_dims: (Optional.) The number of batch dimensions in the input;
|
|
||||||
if not provided, the default of `1` is used.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_shape,
|
input_shape,
|
||||||
filter_shape,
|
filter_shape, # pylint: disable=redefined-builtin
|
||||||
padding,
|
padding,
|
||||||
data_format=None,
|
data_format=None,
|
||||||
strides=None,
|
strides=None,
|
||||||
name=None,
|
name=None):
|
||||||
num_batch_dims=1):
|
filter_shape = filter_shape.with_rank(input_shape.ndims)
|
||||||
# filter shape is always rank num_spatial_dims + 2
|
|
||||||
# and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
|
|
||||||
if input_shape.ndims is not None:
|
|
||||||
filter_shape = filter_shape.with_rank(
|
|
||||||
input_shape.ndims - num_batch_dims + 1)
|
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.name = name
|
self.name = name
|
||||||
# input shape is == num_spatial_dims + num_batch_dims + 1
|
input_shape = input_shape.with_rank(filter_shape.ndims)
|
||||||
# and filter_shape is always rank num_spatial_dims + 2
|
|
||||||
if filter_shape.ndims is not None:
|
|
||||||
input_shape = input_shape.with_rank(
|
|
||||||
filter_shape.ndims + num_batch_dims - 1)
|
|
||||||
if input_shape.ndims is None:
|
if input_shape.ndims is None:
|
||||||
|
raise ValueError("Rank of convolution must be known")
|
||||||
|
if input_shape.ndims < 3 or input_shape.ndims > 5:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Rank of convolution must be known, but saw input_shape.ndims == {}"
|
"`input` and `filter` must have rank at least 3 and at most 5")
|
||||||
.format(input_shape.ndims))
|
conv_dims = input_shape.ndims - 2
|
||||||
if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
|
|
||||||
raise ValueError(
|
|
||||||
"`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
|
|
||||||
"most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
|
|
||||||
.format(input_shape.ndims, num_batch_dims))
|
|
||||||
conv_dims = input_shape.ndims - num_batch_dims - 1
|
|
||||||
if strides is None:
|
if strides is None:
|
||||||
strides = [1] * conv_dims
|
strides = [1] * conv_dims
|
||||||
elif len(strides) != conv_dims:
|
elif len(strides) != conv_dims:
|
||||||
@ -535,7 +520,7 @@ def with_space_to_batch(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
|
input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin
|
||||||
input_shape = input.shape
|
input_shape = input.get_shape()
|
||||||
|
|
||||||
def build_op(num_spatial_dims, padding):
|
def build_op(num_spatial_dims, padding):
|
||||||
return lambda inp, _: op(inp, num_spatial_dims, padding)
|
return lambda inp, _: op(inp, num_spatial_dims, padding)
|
||||||
@ -555,19 +540,18 @@ class _WithSpaceToBatch(object):
|
|||||||
"""Helper class for with_space_to_batch.
|
"""Helper class for with_space_to_batch.
|
||||||
|
|
||||||
Note that this class assumes that shapes of input and filter passed to
|
Note that this class assumes that shapes of input and filter passed to
|
||||||
`__call__` are compatible with `input_shape`, `filter_shape`, and
|
__call__ are compatible with input_shape and filter_shape passed to the
|
||||||
`spatial_dims` passed to the constructor.
|
constructor.
|
||||||
|
|
||||||
Arguments
|
Arguments
|
||||||
input_shape: static shape of input. i.e. input.shape.
|
input_shape: static shape of input. i.e. input.get_shape().
|
||||||
dilation_rate: see `with_space_to_batch`.
|
dilation_rate: see with_space_to_batch
|
||||||
padding: see `with_space_to_batch`.
|
padding: see with_space_to_batch
|
||||||
build_op: Function that maps (num_spatial_dims, paddings) -> (function that
|
build_op: Function that maps (num_spatial_dims, paddings) -> (function that
|
||||||
maps (input, filter) -> output).
|
maps (input, filter) -> output).
|
||||||
filter_shape: see `with_space_to_batch`.
|
filter_shape: see with_space_to_batch
|
||||||
spatial_dims: `see with_space_to_batch`.
|
spatial_dims: see with_space_to_batch
|
||||||
data_format: see `with_space_to_batch`.
|
data_format: see with_space_to_batch
|
||||||
num_batch_dims: (Optional). Number of batch dims in `input_shape`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -577,25 +561,24 @@ class _WithSpaceToBatch(object):
|
|||||||
build_op,
|
build_op,
|
||||||
filter_shape=None,
|
filter_shape=None,
|
||||||
spatial_dims=None,
|
spatial_dims=None,
|
||||||
data_format=None,
|
data_format=None):
|
||||||
num_batch_dims=1):
|
|
||||||
"""Helper class for _with_space_to_batch."""
|
"""Helper class for _with_space_to_batch."""
|
||||||
dilation_rate = ops.convert_to_tensor(
|
dilation_rate = ops.convert_to_tensor(
|
||||||
dilation_rate, dtypes.int32, name="dilation_rate")
|
dilation_rate, dtypes.int32, name="dilation_rate")
|
||||||
if dilation_rate.shape.ndims not in (None, 1):
|
try:
|
||||||
raise ValueError(
|
rate_shape = dilation_rate.get_shape().with_rank(1)
|
||||||
"rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))
|
except ValueError:
|
||||||
|
raise ValueError("rate must be rank 1")
|
||||||
|
|
||||||
if not dilation_rate.shape.is_fully_defined():
|
if not dilation_rate.get_shape().is_fully_defined():
|
||||||
raise ValueError("rate must have known shape, but saw {}"
|
raise ValueError("rate must have known shape")
|
||||||
.format(dilation_rate.shape))
|
|
||||||
|
|
||||||
num_spatial_dims = dilation_rate.shape.dims[0].value
|
num_spatial_dims = rate_shape.dims[0].value
|
||||||
|
|
||||||
if data_format is not None and data_format.startswith("NC"):
|
if data_format is not None and data_format.startswith("NC"):
|
||||||
starting_spatial_dim = num_batch_dims + 1
|
starting_spatial_dim = 2
|
||||||
else:
|
else:
|
||||||
starting_spatial_dim = num_batch_dims
|
starting_spatial_dim = 1
|
||||||
|
|
||||||
if spatial_dims is None:
|
if spatial_dims is None:
|
||||||
spatial_dims = range(starting_spatial_dim,
|
spatial_dims = range(starting_spatial_dim,
|
||||||
@ -605,7 +588,7 @@ class _WithSpaceToBatch(object):
|
|||||||
if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
|
if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"spatial_dims must be a monotonically increasing sequence of "
|
"spatial_dims must be a monotonically increasing sequence of "
|
||||||
"positive integers, but saw: {}".format(orig_spatial_dims))
|
"positive integers")
|
||||||
|
|
||||||
if data_format is not None and data_format.startswith("NC"):
|
if data_format is not None and data_format.startswith("NC"):
|
||||||
expected_input_rank = spatial_dims[-1]
|
expected_input_rank = spatial_dims[-1]
|
||||||
@ -616,16 +599,14 @@ class _WithSpaceToBatch(object):
|
|||||||
input_shape.with_rank_at_least(expected_input_rank)
|
input_shape.with_rank_at_least(expected_input_rank)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"input tensor must have rank at least {}, but saw rank {}"
|
"input tensor must have rank %d at least" % (expected_input_rank))
|
||||||
.format(expected_input_rank, input_shape.ndims))
|
|
||||||
|
|
||||||
const_rate = tensor_util.constant_value(dilation_rate)
|
const_rate = tensor_util.constant_value(dilation_rate)
|
||||||
rate_or_const_rate = dilation_rate
|
rate_or_const_rate = dilation_rate
|
||||||
if const_rate is not None:
|
if const_rate is not None:
|
||||||
rate_or_const_rate = const_rate
|
rate_or_const_rate = const_rate
|
||||||
if np.any(const_rate < 1):
|
if np.any(const_rate < 1):
|
||||||
raise ValueError("dilation_rate must be positive, but saw: {}"
|
raise ValueError("dilation_rate must be positive")
|
||||||
.format(const_rate))
|
|
||||||
if np.all(const_rate == 1):
|
if np.all(const_rate == 1):
|
||||||
self.call = build_op(num_spatial_dims, padding)
|
self.call = build_op(num_spatial_dims, padding)
|
||||||
return
|
return
|
||||||
@ -691,7 +672,6 @@ class _WithSpaceToBatch(object):
|
|||||||
filter_shape = array_ops.shape(filter)
|
filter_shape = array_ops.shape(filter)
|
||||||
base_paddings = _with_space_to_batch_base_paddings(
|
base_paddings = _with_space_to_batch_base_paddings(
|
||||||
filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
|
filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
|
||||||
|
|
||||||
paddings, crops = array_ops.required_space_to_batch_paddings(
|
paddings, crops = array_ops.required_space_to_batch_paddings(
|
||||||
input_shape=input_spatial_shape,
|
input_shape=input_spatial_shape,
|
||||||
base_paddings=base_paddings,
|
base_paddings=base_paddings,
|
||||||
@ -1014,83 +994,31 @@ def convolution_internal(
|
|||||||
data_format=None,
|
data_format=None,
|
||||||
dilations=None,
|
dilations=None,
|
||||||
name=None,
|
name=None,
|
||||||
call_from_convolution=True,
|
call_from_convolution=True):
|
||||||
num_spatial_dims=None):
|
"""Internal function which performs rank agnostic convolution."""
|
||||||
"""Internal function which performs rank agnostic convolution.
|
if isinstance(input.shape, tensor_shape.TensorShape) and \
|
||||||
|
input.shape.rank is not None:
|
||||||
Args:
|
n = len(input.shape) - 2
|
||||||
input: See `convolution`.
|
elif not isinstance(input.shape, tensor_shape.TensorShape) and \
|
||||||
filters: See `convolution`.
|
input.shape is not None:
|
||||||
strides: See `convolution`.
|
n = len(input.shape) - 2
|
||||||
padding: See `convolution`.
|
elif isinstance(filters.shape, tensor_shape.TensorShape) and \
|
||||||
data_format: See `convolution`.
|
filters.shape.rank is not None:
|
||||||
dilations: See `convolution`.
|
|
||||||
name: See `convolution`.
|
|
||||||
call_from_convolution: See `convolution`.
|
|
||||||
num_spatial_dims: (Optional.). It is a integer describing the
|
|
||||||
rank of the spatial dimensions. For `1-D`, `2-D` and `3-D` convolutions,
|
|
||||||
the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
|
|
||||||
This argument is only required to disambiguate the rank of `batch_shape`
|
|
||||||
when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
|
|
||||||
backwards compatibility, if `num_spatial_dims is None` and
|
|
||||||
`filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
|
|
||||||
`1` (i.e., the input is expected to be
|
|
||||||
`[batch_size, num_channels] + input_spatial_shape`
|
|
||||||
or `[batch_size] + input_spatial_shape + [num_channels]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor of shape and dtype matching that of `input`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If input and filter both have unknown shapes, or if
|
|
||||||
`num_spatial_dims` is provided and incompatible with the value
|
|
||||||
estimated from `filters.shape`.
|
|
||||||
"""
|
|
||||||
n = None
|
|
||||||
if isinstance(filters, (list, tuple)):
|
|
||||||
filters = np.asarray(filters)
|
|
||||||
if (isinstance(filters.shape, tensor_shape.TensorShape)
|
|
||||||
and filters.shape.rank is not None):
|
|
||||||
n = len(filters.shape) - 2
|
n = len(filters.shape) - 2
|
||||||
elif (not isinstance(filters.shape, tensor_shape.TensorShape)
|
elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
|
||||||
and filters.shape is not None):
|
filters.shape is not None:
|
||||||
n = len(filters.shape) - 2
|
n = len(filters.shape) - 2
|
||||||
|
|
||||||
if (isinstance(input.shape, tensor_shape.TensorShape)
|
|
||||||
and input.shape.rank is not None):
|
|
||||||
if n is None:
|
|
||||||
n = (num_spatial_dims if num_spatial_dims is not None
|
|
||||||
else len(input.shape) - 2)
|
|
||||||
num_batch_dims = len(input.shape) - n - 1
|
|
||||||
elif (not isinstance(input.shape, tensor_shape.TensorShape)
|
|
||||||
and input.shape is not None):
|
|
||||||
if n is None:
|
|
||||||
n = (num_spatial_dims if num_spatial_dims is not None
|
|
||||||
else len(input.shape) - 2)
|
|
||||||
num_batch_dims = len(input.shape) - n - 1
|
|
||||||
else:
|
else:
|
||||||
num_batch_dims = 1 # Default behavior if it cannot be estimated.
|
|
||||||
|
|
||||||
if n is None:
|
|
||||||
raise ValueError("rank of input or filter must be known")
|
raise ValueError("rank of input or filter must be known")
|
||||||
|
|
||||||
if num_spatial_dims is not None and n != num_spatial_dims:
|
|
||||||
raise ValueError(
|
|
||||||
"inconsistent estimate of spatial dims ({}) vs. actual passed "
|
|
||||||
"num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
|
|
||||||
"but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
|
|
||||||
|
|
||||||
if not 1 <= n <= 3:
|
if not 1 <= n <= 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
|
"Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
|
||||||
"of 1, 2 or 3 but saw {}. num_batch_dims: {}."
|
|
||||||
.format(n, num_batch_dims))
|
|
||||||
|
|
||||||
if data_format is None:
|
if data_format is None:
|
||||||
channel_index = num_batch_dims + n
|
channel_index = n + 1
|
||||||
else:
|
else:
|
||||||
channel_index = (
|
channel_index = 1 if data_format.startswith("NC") else n + 1
|
||||||
num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
|
|
||||||
|
|
||||||
strides = _get_sequence(strides, n, channel_index, "strides")
|
strides = _get_sequence(strides, n, channel_index, "strides")
|
||||||
dilations = _get_sequence(dilations, n, channel_index, "dilations")
|
dilations = _get_sequence(dilations, n, channel_index, "dilations")
|
||||||
@ -1103,7 +1031,7 @@ def convolution_internal(
|
|||||||
scope = "convolution"
|
scope = "convolution"
|
||||||
|
|
||||||
with ops.name_scope(name, scope, [input, filters]) as name:
|
with ops.name_scope(name, scope, [input, filters]) as name:
|
||||||
conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
|
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
|
||||||
|
|
||||||
if device_context.enclosing_tpu_context() is not None or all(
|
if device_context.enclosing_tpu_context() is not None or all(
|
||||||
i == 1 for i in dilations):
|
i == 1 for i in dilations):
|
||||||
@ -1133,8 +1061,7 @@ def convolution_internal(
|
|||||||
strides=strides,
|
strides=strides,
|
||||||
dilation_rate=dilations,
|
dilation_rate=dilations,
|
||||||
name=name,
|
name=name,
|
||||||
data_format=data_format,
|
data_format=data_format)
|
||||||
num_spatial_dims=n)
|
|
||||||
return op(input, filters)
|
return op(input, filters)
|
||||||
|
|
||||||
|
|
||||||
@ -1142,34 +1069,17 @@ class Convolution(object):
|
|||||||
"""Helper class for convolution.
|
"""Helper class for convolution.
|
||||||
|
|
||||||
Note that this class assumes that shapes of input and filter passed to
|
Note that this class assumes that shapes of input and filter passed to
|
||||||
`__call__` are compatible with `input_shape`, `filter_shape`, and
|
__call__ are compatible with input_shape and filter_shape passed to the
|
||||||
`num_spatial_dims` passed to the constructor.
|
constructor.
|
||||||
|
|
||||||
Arguments
|
Arguments
|
||||||
input_shape: static shape of input. i.e. input.shape. Its length is
|
input_shape: static shape of input. i.e. input.get_shape().
|
||||||
`batch_shape + input_spatial_shape + [num_channels]` if `data_format`
|
filter_shape: static shape of the filter. i.e. filter.get_shape().
|
||||||
does not start with `NC`, or
|
padding: see convolution.
|
||||||
`batch_shape + [num_channels] + input_spatial_shape` if `data_format`
|
|
||||||
starts with `NC`.
|
|
||||||
filter_shape: static shape of the filter. i.e. filter.shape.
|
|
||||||
padding: The padding algorithm, must be "SAME" or "VALID".
|
|
||||||
strides: see convolution.
|
strides: see convolution.
|
||||||
dilation_rate: see convolution.
|
dilation_rate: see convolution.
|
||||||
name: see convolution.
|
name: see convolution.
|
||||||
data_format: A string or `None`. Specifies whether the channel dimension of
|
data_format: see convolution.
|
||||||
the `input` and output is the last dimension (if `data_format` is `None`
|
|
||||||
or does not start with `NC`), or the first post-batch dimension (i.e. if
|
|
||||||
`data_format` starts with `NC`).
|
|
||||||
num_spatial_dims: (Usually optional.) Python integer, the rank of the
|
|
||||||
spatial and channel dimensions. For `1-D`, `2-D` and `3-D` convolutions,
|
|
||||||
the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
|
|
||||||
This argument is only required to disambiguate the rank of `batch_shape`
|
|
||||||
when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For
|
|
||||||
backwards compatibility, if `num_spatial_dims is None` and
|
|
||||||
`filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
|
|
||||||
`1` (i.e., the input is expected to be
|
|
||||||
`[batch_size, num_channels] + input_spatial_shape`
|
|
||||||
or `[batch_size] + input_spatial_shape + [num_channels]`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -1179,72 +1089,40 @@ class Convolution(object):
|
|||||||
strides=None,
|
strides=None,
|
||||||
dilation_rate=None,
|
dilation_rate=None,
|
||||||
name=None,
|
name=None,
|
||||||
data_format=None,
|
data_format=None):
|
||||||
num_spatial_dims=None):
|
|
||||||
"""Helper function for convolution."""
|
"""Helper function for convolution."""
|
||||||
num_batch_dims = None
|
num_total_dims = filter_shape.ndims
|
||||||
filter_shape = tensor_shape.as_shape(filter_shape)
|
if num_total_dims is None:
|
||||||
input_shape = tensor_shape.as_shape(input_shape)
|
num_total_dims = input_shape.ndims
|
||||||
|
if num_total_dims is None:
|
||||||
|
raise ValueError("rank of input or filter must be known")
|
||||||
|
|
||||||
if filter_shape.ndims is not None:
|
num_spatial_dims = num_total_dims - 2
|
||||||
if (num_spatial_dims is not None and
|
|
||||||
filter_shape.ndims != num_spatial_dims + 2):
|
|
||||||
raise ValueError(
|
|
||||||
"Expected filter_shape.ndims == num_spatial_dims + 2, "
|
|
||||||
"but saw filter_shape.ndims == {} and num_spatial_dims == {}"
|
|
||||||
.format(filter_shape.ndims, num_spatial_dims))
|
|
||||||
else:
|
|
||||||
num_spatial_dims = filter_shape.ndims - 2
|
|
||||||
|
|
||||||
if input_shape.ndims is not None and num_spatial_dims is not None:
|
try:
|
||||||
num_batch_dims = input_shape.ndims - num_spatial_dims - 1
|
input_shape.with_rank(num_spatial_dims + 2)
|
||||||
|
except ValueError:
|
||||||
if num_spatial_dims is None:
|
|
||||||
num_spatial_dims = input_shape.ndims - 2
|
|
||||||
else:
|
|
||||||
if input_shape.ndims is not None:
|
|
||||||
if input_shape.ndims < num_spatial_dims + 2:
|
|
||||||
raise ValueError(
|
|
||||||
"Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
|
|
||||||
"input_shape.ndims == {} and num_spatial_dims == {}"
|
|
||||||
.format(input_shape.ndims, num_spatial_dims))
|
|
||||||
else:
|
|
||||||
if num_batch_dims is None:
|
|
||||||
num_batch_dims = input_shape.ndims - num_spatial_dims - 1
|
|
||||||
|
|
||||||
if num_spatial_dims is None:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot estimate num_spatial_dims since input_shape.ndims is None, "
|
"input tensor must have rank %d" % (num_spatial_dims + 2))
|
||||||
"filter_shape.ndims is None, and argument num_spatial_dims is also "
|
|
||||||
"None.")
|
|
||||||
|
|
||||||
if num_batch_dims is None:
|
try:
|
||||||
num_batch_dims = 1
|
filter_shape.with_rank(num_spatial_dims + 2)
|
||||||
|
except ValueError:
|
||||||
if num_batch_dims < 1:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_batch_dims should be >= 1, but saw {}. num_batch_dims was "
|
"filter tensor must have rank %d" % (num_spatial_dims + 2))
|
||||||
"estimated as `input_shape.ndims - num_spatial_dims - 1` and "
|
|
||||||
"num_spatial_dims was either provided or estimated as "
|
|
||||||
"`filter_shape.ndims - 2`. input_shape.ndims: {}, "
|
|
||||||
"num_spatial_dims: {}, filter_shape.ndims: {}"
|
|
||||||
.format(num_batch_dims, input_shape.ndims, num_spatial_dims,
|
|
||||||
filter_shape.ndims))
|
|
||||||
|
|
||||||
if data_format is None or not data_format.startswith("NC"):
|
if data_format is None or not data_format.startswith("NC"):
|
||||||
input_channels_dim = tensor_shape.dimension_at_index(
|
input_channels_dim = tensor_shape.dimension_at_index(
|
||||||
input_shape, num_spatial_dims + num_batch_dims)
|
input_shape, num_spatial_dims + 1)
|
||||||
spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
|
spatial_dims = range(1, num_spatial_dims + 1)
|
||||||
else:
|
else:
|
||||||
input_channels_dim = tensor_shape.dimension_at_index(
|
input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
|
||||||
input_shape, num_batch_dims)
|
spatial_dims = range(2, num_spatial_dims + 2)
|
||||||
spatial_dims = range(
|
|
||||||
num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
|
|
||||||
|
|
||||||
if not input_channels_dim.is_compatible_with(
|
if not input_channels_dim.is_compatible_with(
|
||||||
filter_shape[num_spatial_dims]):
|
filter_shape[num_spatial_dims]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Number of input channels does not match corresponding dimension of "
|
"number of input channels does not match corresponding dimension of "
|
||||||
"filter, {} != {}".format(input_channels_dim,
|
"filter, {} != {}".format(input_channels_dim,
|
||||||
filter_shape[num_spatial_dims]))
|
filter_shape[num_spatial_dims]))
|
||||||
|
|
||||||
@ -1258,8 +1136,6 @@ class Convolution(object):
|
|||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.name = name
|
self.name = name
|
||||||
self.dilation_rate = dilation_rate
|
self.dilation_rate = dilation_rate
|
||||||
self.num_batch_dims = num_batch_dims
|
|
||||||
self.num_spatial_dims = num_spatial_dims
|
|
||||||
self.conv_op = _WithSpaceToBatch(
|
self.conv_op = _WithSpaceToBatch(
|
||||||
input_shape,
|
input_shape,
|
||||||
dilation_rate=dilation_rate,
|
dilation_rate=dilation_rate,
|
||||||
@ -1267,8 +1143,7 @@ class Convolution(object):
|
|||||||
build_op=self._build_op,
|
build_op=self._build_op,
|
||||||
filter_shape=filter_shape,
|
filter_shape=filter_shape,
|
||||||
spatial_dims=spatial_dims,
|
spatial_dims=spatial_dims,
|
||||||
data_format=data_format,
|
data_format=data_format)
|
||||||
num_batch_dims=num_batch_dims)
|
|
||||||
|
|
||||||
def _build_op(self, _, padding):
|
def _build_op(self, _, padding):
|
||||||
return _NonAtrousConvolution(
|
return _NonAtrousConvolution(
|
||||||
@ -1277,8 +1152,7 @@ class Convolution(object):
|
|||||||
padding=padding,
|
padding=padding,
|
||||||
data_format=self.data_format,
|
data_format=self.data_format,
|
||||||
strides=self.strides,
|
strides=self.strides,
|
||||||
name=self.name,
|
name=self.name)
|
||||||
num_batch_dims=self.num_batch_dims)
|
|
||||||
|
|
||||||
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
|
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
|
||||||
# TPU convolution supports dilations greater than 1.
|
# TPU convolution supports dilations greater than 1.
|
||||||
@ -1291,8 +1165,7 @@ class Convolution(object):
|
|||||||
data_format=self.data_format,
|
data_format=self.data_format,
|
||||||
dilations=self.dilation_rate,
|
dilations=self.dilation_rate,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
call_from_convolution=False,
|
call_from_convolution=False)
|
||||||
num_spatial_dims=self.num_spatial_dims)
|
|
||||||
else:
|
else:
|
||||||
return self.conv_op(inp, filter)
|
return self.conv_op(inp, filter)
|
||||||
|
|
||||||
@ -2519,42 +2392,6 @@ def conv2d_transpose_v2(
|
|||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
def _conv2d_expanded_batch(
|
|
||||||
input, # pylint: disable=redefined-builtin
|
|
||||||
filters,
|
|
||||||
strides,
|
|
||||||
padding,
|
|
||||||
data_format,
|
|
||||||
dilations,
|
|
||||||
name):
|
|
||||||
"""Helper function for `convolution_internal`; handles expanded batches."""
|
|
||||||
# Try really hard to avoid modifying the legacy name scopes - return early.
|
|
||||||
shape = getattr(input, "shape", None)
|
|
||||||
if shape is not None:
|
|
||||||
ndims = getattr(shape, "ndims", -1)
|
|
||||||
if ndims == -1: ndims = len(shape)
|
|
||||||
if ndims in (4, 3, 2, 1, 0, None):
|
|
||||||
return gen_nn_ops.conv2d(
|
|
||||||
input,
|
|
||||||
filter=filters,
|
|
||||||
strides=strides,
|
|
||||||
padding=padding,
|
|
||||||
data_format=data_format,
|
|
||||||
dilations=dilations,
|
|
||||||
name=name)
|
|
||||||
return _squeeze_batch_dims(
|
|
||||||
input,
|
|
||||||
functools.partial(
|
|
||||||
gen_nn_ops.conv2d,
|
|
||||||
filter=filters,
|
|
||||||
strides=strides,
|
|
||||||
padding=padding,
|
|
||||||
data_format=data_format,
|
|
||||||
dilations=dilations),
|
|
||||||
inner_rank=3,
|
|
||||||
name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("nn.atrous_conv2d_transpose")
|
@tf_export("nn.atrous_conv2d_transpose")
|
||||||
@dispatch.add_dispatch_support
|
@dispatch.add_dispatch_support
|
||||||
def atrous_conv2d_transpose(value,
|
def atrous_conv2d_transpose(value,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user