Internal change

PiperOrigin-RevId: 312712086
Change-Id: Iba2311e8ac40ebe73765f273ef48f5550d76fcfc
A. Unique TensorFlower 2020-05-21 11:54:42 -07:00 committed by TensorFlower Gardener
parent 1d8bc7222d
commit d3cd2a76cc
2 changed files with 81 additions and 296 deletions


@@ -455,58 +455,6 @@ class Conv2DTest(test.TestCase):
         conv1,
         self.evaluate(conv2).reshape(conv1.shape))
 
-  @test_util.run_in_graph_and_eager_modes
-  def testConvolutionClass2DExpandedBatch(self):
-    tensor_in_sizes_batch = [10, 2, 3, 3]
-    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
-    filter_in_sizes = [1, 1, 3, 3]
-    filter_in = self._CreateNumpyTensor(filter_in_sizes)
-    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
-    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
-    convolver1 = nn_ops.Convolution(
-        input_shape=x1.shape,
-        filter_shape=filter_in.shape,
-        strides=[1, 1],
-        padding="VALID")
-    self.assertEqual(convolver1.num_batch_dims, 1)
-    convolver2 = nn_ops.Convolution(
-        input_shape=x2.shape,
-        filter_shape=filter_in.shape,
-        strides=[1, 1],
-        padding="VALID")
-    self.assertEqual(convolver2.num_batch_dims, 2)
-    conv1 = convolver1(x1, filter_in)
-    conv2 = convolver2(x2, filter_in)
-    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
-    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
-    self.assertAllEqual(
-        conv1,
-        self.evaluate(conv2).reshape(conv1.shape))
-
-  @test_util.run_in_graph_and_eager_modes
-  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
-    tensor_in_sizes_batch = [10, 2, 3, 3]
-    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
-    filter_in_sizes = [1, 1, 3, 3]
-    filter_in = self._CreateNumpyTensor(filter_in_sizes)
-    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
-    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
-    conv1 = nn_ops.convolution(
-        x1,
-        filter_in,
-        strides=[1, 1],
-        padding="VALID")
-    conv2 = nn_ops.convolution(
-        x2,
-        filter_in,
-        strides=[1, 1],
-        padding="VALID")
-    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
-    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
-    self.assertAllEqual(
-        conv1,
-        self.evaluate(conv2).reshape(conv1.shape))
-
   @test_util.run_in_graph_and_eager_modes
   def testConv2D2x2Filter2x1Dilation(self):
     self._VerifyDilatedConvValues(
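
The two tests deleted above pin down a single property: convolving with the batch dimensions flattened ([10, ...]) and then reshaping must equal convolving with the expanded batch shape ([2, 5, ...]) directly. Because the filter is spatially 1x1, the convolution reduces to a matmul over the channel axis, so the property can be checked in plain NumPy. The sketch below is illustrative (shapes borrowed from the tests), not TensorFlow's implementation:

import numpy as np

# NHWC input with batch 10, and a 1x1 filter with 3 in / 3 out channels.
x1 = np.arange(10 * 2 * 3 * 3, dtype=np.float32).reshape([10, 2, 3, 3])
filt = np.arange(3 * 3, dtype=np.float32).reshape([3, 3])  # 1x1 conv == per-pixel matmul

conv1 = x1 @ filt                    # shape [10, 2, 3, 3]
x2 = x1.reshape([2, 5, 2, 3, 3])     # expanded batch dims [2, 5]
conv2 = x2 @ filt                    # shape [2, 5, 2, 3, 3]

assert np.array_equal(conv1, conv2.reshape(conv1.shape))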


@@ -131,9 +131,9 @@ def _non_atrous_convolution(
   """
   with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
     input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-    input_shape = input.shape
+    input_shape = input.get_shape()
     filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
-    filter_shape = filter.shape
+    filter_shape = filter.get_shape()
     op = _NonAtrousConvolution(
         input_shape,
         filter_shape=filter_shape,
@@ -148,51 +148,36 @@ class _NonAtrousConvolution(object):
   """Helper class for _non_atrous_convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  `__call__` are compatible with `input_shape` and filter_shape passed to the
+  __call__ are compatible with input_shape and filter_shape passed to the
   constructor.
 
   Arguments:
-    input_shape: static input shape, i.e. input.shape.
-    filter_shape: static filter shape, i.e. filter.shape.
+    input_shape: static input shape, i.e. input.get_shape().
+    filter_shape: static filter shape, i.e. filter.get_shape().
     padding: see _non_atrous_convolution.
     data_format: see _non_atrous_convolution.
     strides: see _non_atrous_convolution.
     name: see _non_atrous_convolution.
-    num_batch_dims: (Optional.) The number of batch dimensions in the input;
-      if not provided, the default of `1` is used.
   """
 
   def __init__(
       self,
       input_shape,
-      filter_shape,
+      filter_shape,  # pylint: disable=redefined-builtin
       padding,
       data_format=None,
       strides=None,
-      name=None,
-      num_batch_dims=1):
-    # filter shape is always rank num_spatial_dims + 2
-    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
-    if input_shape.ndims is not None:
-      filter_shape = filter_shape.with_rank(
-          input_shape.ndims - num_batch_dims + 1)
+      name=None):
+    filter_shape = filter_shape.with_rank(input_shape.ndims)
     self.padding = padding
     self.name = name
-    # input shape is == num_spatial_dims + num_batch_dims + 1
-    # and filter_shape is always rank num_spatial_dims + 2
-    if filter_shape.ndims is not None:
-      input_shape = input_shape.with_rank(
-          filter_shape.ndims + num_batch_dims - 1)
+    input_shape = input_shape.with_rank(filter_shape.ndims)
     if input_shape.ndims is None:
-      raise ValueError(
-          "Rank of convolution must be known, but saw input_shape.ndims == {}"
-          .format(input_shape.ndims))
-    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
+      raise ValueError("Rank of convolution must be known")
+    if input_shape.ndims < 3 or input_shape.ndims > 5:
       raise ValueError(
-          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
-          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
-          .format(input_shape.ndims, num_batch_dims))
-    conv_dims = input_shape.ndims - num_batch_dims - 1
+          "`input` and `filter` must have rank at least 3 and at most 5")
+    conv_dims = input_shape.ndims - 2
     if strides is None:
       strides = [1] * conv_dims
     elif len(strides) != conv_dims:
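
The restored constructor ties the filter rank directly to the input rank with TensorShape.with_rank, instead of deriving it from num_batch_dims. A quick sketch of the with_rank contract this code relies on (standard tf.TensorShape behavior):

import tensorflow as tf

s = tf.TensorShape([None, 28, 28, 3])
print(s.with_rank(4))  # (None, 28, 28, 3): compatible, unknown dims preserved
try:
  s.with_rank(5)       # known rank 4 is incompatible with 5
except ValueError as e:
  print("rejected:", e)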
@@ -535,7 +520,7 @@ def with_space_to_batch(
   """
   input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-  input_shape = input.shape
+  input_shape = input.get_shape()
 
   def build_op(num_spatial_dims, padding):
     return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -555,19 +540,18 @@ class _WithSpaceToBatch(object):
   """Helper class for with_space_to_batch.
 
   Note that this class assumes that shapes of input and filter passed to
-  `__call__` are compatible with `input_shape`, `filter_shape`, and
-  `spatial_dims` passed to the constructor.
+  __call__ are compatible with input_shape and filter_shape passed to the
+  constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.shape.
-    dilation_rate: see `with_space_to_batch`.
-    padding: see `with_space_to_batch`.
+    input_shape: static shape of input. i.e. input.get_shape().
+    dilation_rate: see with_space_to_batch
+    padding: see with_space_to_batch
     build_op: Function that maps (num_spatial_dims, paddings) -> (function that
       maps (input, filter) -> output).
-    filter_shape: see `with_space_to_batch`.
-    spatial_dims: `see with_space_to_batch`.
-    data_format: see `with_space_to_batch`.
-    num_batch_dims: (Optional). Number of batch dims in `input_shape`.
+    filter_shape: see with_space_to_batch
+    spatial_dims: see with_space_to_batch
+    data_format: see with_space_to_batch
   """
 
   def __init__(self,
@@ -577,25 +561,24 @@ class _WithSpaceToBatch(object):
                build_op,
                filter_shape=None,
                spatial_dims=None,
-               data_format=None,
-               num_batch_dims=1):
+               data_format=None):
     """Helper class for _with_space_to_batch."""
     dilation_rate = ops.convert_to_tensor(
         dilation_rate, dtypes.int32, name="dilation_rate")
-    if dilation_rate.shape.ndims not in (None, 1):
-      raise ValueError(
-          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))
+    try:
+      rate_shape = dilation_rate.get_shape().with_rank(1)
+    except ValueError:
+      raise ValueError("rate must be rank 1")
 
-    if not dilation_rate.shape.is_fully_defined():
-      raise ValueError("rate must have known shape, but saw {}"
-                       .format(dilation_rate.shape))
+    if not dilation_rate.get_shape().is_fully_defined():
+      raise ValueError("rate must have known shape")
 
-    num_spatial_dims = dilation_rate.shape.dims[0].value
+    num_spatial_dims = rate_shape.dims[0].value
 
     if data_format is not None and data_format.startswith("NC"):
-      starting_spatial_dim = num_batch_dims + 1
+      starting_spatial_dim = 2
     else:
-      starting_spatial_dim = num_batch_dims
+      starting_spatial_dim = 1
 
     if spatial_dims is None:
       spatial_dims = range(starting_spatial_dim,
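
After the revert, the number of spatial dimensions is read from the static shape of dilation_rate, which must be a fully defined rank-1 tensor, and the first spatial axis is fixed at 2 for NC* layouts and 1 otherwise. A minimal sketch of that validation path, using the same public TensorShape methods:

import tensorflow as tf

dilation_rate = tf.convert_to_tensor([2, 2], tf.int32, name="dilation_rate")
rate_shape = dilation_rate.get_shape().with_rank(1)  # raises ValueError if not rank 1
assert dilation_rate.get_shape().is_fully_defined()
num_spatial_dims = rate_shape.dims[0].value  # 2 -> a 2-D convolution
print(num_spatial_dims)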
@@ -605,7 +588,7 @@ class _WithSpaceToBatch(object):
     if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
       raise ValueError(
           "spatial_dims must be a monotonically increasing sequence of "
-          "positive integers, but saw: {}".format(orig_spatial_dims))
+          "positive integers")
 
     if data_format is not None and data_format.startswith("NC"):
       expected_input_rank = spatial_dims[-1]
@@ -616,16 +599,14 @@ class _WithSpaceToBatch(object):
       input_shape.with_rank_at_least(expected_input_rank)
     except ValueError:
       raise ValueError(
-          "input tensor must have rank at least {}, but saw rank {}"
-          .format(expected_input_rank, input_shape.ndims))
+          "input tensor must have rank %d at least" % (expected_input_rank))
 
     const_rate = tensor_util.constant_value(dilation_rate)
     rate_or_const_rate = dilation_rate
     if const_rate is not None:
       rate_or_const_rate = const_rate
       if np.any(const_rate < 1):
-        raise ValueError("dilation_rate must be positive, but saw: {}"
-                         .format(const_rate))
+        raise ValueError("dilation_rate must be positive")
       if np.all(const_rate == 1):
         self.call = build_op(num_spatial_dims, padding)
         return
@@ -691,7 +672,6 @@ class _WithSpaceToBatch(object):
       filter_shape = array_ops.shape(filter)
       base_paddings = _with_space_to_batch_base_paddings(
           filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
-
     paddings, crops = array_ops.required_space_to_batch_paddings(
         input_shape=input_spatial_shape,
         base_paddings=base_paddings,
@@ -1014,83 +994,31 @@ def convolution_internal(
     data_format=None,
     dilations=None,
     name=None,
-    call_from_convolution=True,
-    num_spatial_dims=None):
-  """Internal function which performs rank agnostic convolution.
-
-  Args:
-    input: See `convolution`.
-    filters: See `convolution`.
-    strides: See `convolution`.
-    padding: See `convolution`.
-    data_format: See `convolution`.
-    dilations: See `convolution`.
-    name: See `convolution`.
-    call_from_convolution: See `convolution`.
-    num_spatial_dims: (Optional.)  It is a integer describing the
-      rank of the spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
-      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
-      This argument is only required to disambiguate the rank of `batch_shape`
-      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
-      backwards compatibility, if `num_spatial_dims is None` and
-      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
-      `1` (i.e., the input is expected to be
-      `[batch_size, num_channels] + input_spatial_shape`
-      or `[batch_size] + input_spatial_shape + [num_channels]`.
-
-  Returns:
-    A tensor of shape and dtype matching that of `input`.
-
-  Raises:
-    ValueError: If input and filter both have unknown shapes, or if
-      `num_spatial_dims` is provided and incompatible with the value
-      estimated from `filters.shape`.
-  """
-  n = None
-  if isinstance(filters, (list, tuple)):
-    filters = np.asarray(filters)
-  if (isinstance(filters.shape, tensor_shape.TensorShape)
-      and filters.shape.rank is not None):
-    n = len(filters.shape) - 2
-  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
-        and filters.shape is not None):
-    n = len(filters.shape) - 2
-
-  if (isinstance(input.shape, tensor_shape.TensorShape)
-      and input.shape.rank is not None):
-    if n is None:
-      n = (num_spatial_dims if num_spatial_dims is not None
-           else len(input.shape) - 2)
-    num_batch_dims = len(input.shape) - n - 1
-  elif (not isinstance(input.shape, tensor_shape.TensorShape)
-        and input.shape is not None):
-    if n is None:
-      n = (num_spatial_dims if num_spatial_dims is not None
-           else len(input.shape) - 2)
-    num_batch_dims = len(input.shape) - n - 1
+    call_from_convolution=True):
+  """Internal function which performs rank agnostic convolution."""
+  if isinstance(input.shape, tensor_shape.TensorShape) and \
+      input.shape.rank is not None:
+    n = len(input.shape) - 2
+  elif not isinstance(input.shape, tensor_shape.TensorShape) and \
+      input.shape is not None:
+    n = len(input.shape) - 2
+  elif isinstance(filters.shape, tensor_shape.TensorShape) and \
+      filters.shape.rank is not None:
+    n = len(filters.shape) - 2
+  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
+      filters.shape is not None:
+    n = len(filters.shape) - 2
   else:
-    num_batch_dims = 1  # Default behavior if it cannot be estimated.
-
-  if n is None:
     raise ValueError("rank of input or filter must be known")
 
-  if num_spatial_dims is not None and n != num_spatial_dims:
-    raise ValueError(
-        "inconsistent estimate of spatial dims ({}) vs. actual passed "
-        "num_spatial_dims ({}).  n was estimated as len(filters.shape) - 2, "
-        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
-
   if not 1 <= n <= 3:
     raise ValueError(
-        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
-        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}."
-        .format(n, num_batch_dims))
+        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
 
   if data_format is None:
-    channel_index = num_batch_dims + n
+    channel_index = n + 1
   else:
-    channel_index = (
-        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
+    channel_index = 1 if data_format.startswith("NC") else n + 1
 
   strides = _get_sequence(strides, n, channel_index, "strides")
   dilations = _get_sequence(dilations, n, channel_index, "dilations")
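
On the reverted path the channel axis is back at a fixed position: index 1 for NC* layouts, n + 1 otherwise, and _get_sequence expands scalar or length-n strides/dilations to full rank n + 2. The helper below is a hypothetical re-implementation of that expansion for illustration only (not TF's _get_sequence):

def expand_strides(strides, n, channel_index):
  # Expand an int or length-n list to rank n + 2, leaving 1 at the
  # batch slot (index 0) and at channel_index.
  if isinstance(strides, int):
    strides = [strides] * n
  full = [1] * (n + 2)
  spatial = [i for i in range(1, n + 2) if i != channel_index]
  for pos, s in zip(spatial, strides):
    full[pos] = s
  return full

print(expand_strides(2, 2, channel_index=3))  # NHWC: [1, 2, 2, 1]
print(expand_strides(2, 2, channel_index=1))  # NCHW: [1, 1, 2, 2]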
@@ -1103,7 +1031,7 @@ def convolution_internal(
     scope = "convolution"
 
   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
 
     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1133,8 +1061,7 @@ def convolution_internal(
           strides=strides,
           dilation_rate=dilations,
           name=name,
-          data_format=data_format,
-          num_spatial_dims=n)
+          data_format=data_format)
       return op(input, filters)
@@ -1142,34 +1069,17 @@ class Convolution(object):
   """Helper class for convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  `__call__` are compatible with `input_shape`, `filter_shape`, and
-  `num_spatial_dims` passed to the constructor.
+  __call__ are compatible with input_shape and filter_shape passed to the
+  constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.shape.  Its length is
-      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
-      does not start with `NC`, or
-      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
-      starts with `NC`.
-    filter_shape: static shape of the filter. i.e. filter.shape.
-    padding: The padding algorithm, must be "SAME" or "VALID".
+    input_shape: static shape of input. i.e. input.get_shape().
+    filter_shape: static shape of the filter. i.e. filter.get_shape().
+    padding: see convolution.
     strides: see convolution.
     dilation_rate: see convolution.
     name: see convolution.
-    data_format: A string or `None`.  Specifies whether the channel dimension of
-      the `input` and output is the last dimension (if `data_format` is `None`
-      or does not start with `NC`), or the first post-batch dimension (i.e. if
-      `data_format` starts with `NC`).
-    num_spatial_dims: (Usually optional.)  Python integer, the rank of the
-      spatial and channel dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
-      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
-      This argument is only required to disambiguate the rank of `batch_shape`
-      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
-      backwards compatibility, if `num_spatial_dims is None` and
-      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
-      `1` (i.e., the input is expected to be
-      `[batch_size, num_channels] + input_spatial_shape`
-      or `[batch_size] + input_spatial_shape + [num_channels]`.
+    data_format: see convolution.
   """
 
   def __init__(self,
@@ -1179,72 +1089,40 @@ class Convolution(object):
                strides=None,
                dilation_rate=None,
                name=None,
-               data_format=None,
-               num_spatial_dims=None):
+               data_format=None):
     """Helper function for convolution."""
-    num_batch_dims = None
-    filter_shape = tensor_shape.as_shape(filter_shape)
-    input_shape = tensor_shape.as_shape(input_shape)
-
-    if filter_shape.ndims is not None:
-      if (num_spatial_dims is not None and
-          filter_shape.ndims != num_spatial_dims + 2):
-        raise ValueError(
-            "Expected filter_shape.ndims == num_spatial_dims + 2, "
-            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
-            .format(filter_shape.ndims, num_spatial_dims))
-      else:
-        num_spatial_dims = filter_shape.ndims - 2
-
-    if input_shape.ndims is not None and num_spatial_dims is not None:
-      num_batch_dims = input_shape.ndims - num_spatial_dims - 1
-    if num_spatial_dims is None:
-      num_spatial_dims = input_shape.ndims - 2
-    else:
-      if input_shape.ndims is not None:
-        if input_shape.ndims < num_spatial_dims + 2:
-          raise ValueError(
-              "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
-              "input_shape.ndims == {} and num_spatial_dims == {}"
-              .format(input_shape.ndims, num_spatial_dims))
-        else:
-          if num_batch_dims is None:
-            num_batch_dims = input_shape.ndims - num_spatial_dims - 1
-
-    if num_spatial_dims is None:
-      raise ValueError(
-          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
-          "filter_shape.ndims is None, and argument num_spatial_dims is also "
-          "None.")
-
-    if num_batch_dims is None:
-      num_batch_dims = 1
-
-    if num_batch_dims < 1:
-      raise ValueError(
-          "num_batch_dims should be >= 1, but saw {}.  num_batch_dims was "
-          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
-          "num_spatial_dims was either provided or estimated as "
-          "`filter_shape.ndims - 2`.  input_shape.ndims: {}, "
-          "num_spatial_dims: {}, filter_shape.ndims: {}"
-          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
-                  filter_shape.ndims))
+    num_total_dims = filter_shape.ndims
+    if num_total_dims is None:
+      num_total_dims = input_shape.ndims
+    if num_total_dims is None:
+      raise ValueError("rank of input or filter must be known")
+
+    num_spatial_dims = num_total_dims - 2
+
+    try:
+      input_shape.with_rank(num_spatial_dims + 2)
+    except ValueError:
+      raise ValueError(
+          "input tensor must have rank %d" % (num_spatial_dims + 2))
+
+    try:
+      filter_shape.with_rank(num_spatial_dims + 2)
+    except ValueError:
+      raise ValueError(
+          "filter tensor must have rank %d" % (num_spatial_dims + 2))
 
     if data_format is None or not data_format.startswith("NC"):
       input_channels_dim = tensor_shape.dimension_at_index(
-          input_shape, num_spatial_dims + num_batch_dims)
-      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
+          input_shape, num_spatial_dims + 1)
+      spatial_dims = range(1, num_spatial_dims + 1)
     else:
-      input_channels_dim = tensor_shape.dimension_at_index(
-          input_shape, num_batch_dims)
-      spatial_dims = range(
-          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
+      input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
+      spatial_dims = range(2, num_spatial_dims + 2)
 
     if not input_channels_dim.is_compatible_with(
         filter_shape[num_spatial_dims]):
       raise ValueError(
-          "Number of input channels does not match corresponding dimension of "
+          "number of input channels does not match corresponding dimension of "
          "filter, {} != {}".format(input_channels_dim,
                                    filter_shape[num_spatial_dims]))
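
With num_batch_dims gone, the channel dimension again sits at a fixed index: num_spatial_dims + 1 for channels-last layouts, 1 for NC*. A short sketch of the compatibility check, using the same tensor_shape utilities the restored code calls:

from tensorflow.python.framework import tensor_shape

input_shape = tensor_shape.TensorShape([None, 8, 8, 3])  # NHWC, channels last
filter_shape = tensor_shape.TensorShape([1, 1, 3, 16])   # spatial + [in, out]
num_spatial_dims = filter_shape.ndims - 2                # 2

in_channels = tensor_shape.dimension_at_index(input_shape, num_spatial_dims + 1)
print(in_channels.is_compatible_with(filter_shape[num_spatial_dims]))  # True: 3 vs 3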
@@ -1258,8 +1136,6 @@ class Convolution(object):
     self.padding = padding
     self.name = name
     self.dilation_rate = dilation_rate
-    self.num_batch_dims = num_batch_dims
-    self.num_spatial_dims = num_spatial_dims
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1267,8 +1143,7 @@ class Convolution(object):
         build_op=self._build_op,
         filter_shape=filter_shape,
         spatial_dims=spatial_dims,
-        data_format=data_format,
-        num_batch_dims=num_batch_dims)
+        data_format=data_format)
 
   def _build_op(self, _, padding):
     return _NonAtrousConvolution(
@@ -1277,8 +1152,7 @@ class Convolution(object):
         padding=padding,
         data_format=self.data_format,
         strides=self.strides,
-        name=self.name,
-        num_batch_dims=self.num_batch_dims)
+        name=self.name)
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
     # TPU convolution supports dilations greater than 1.
@@ -1291,8 +1165,7 @@ class Convolution(object):
           data_format=self.data_format,
           dilations=self.dilation_rate,
           name=self.name,
-          call_from_convolution=False,
-          num_spatial_dims=self.num_spatial_dims)
+          call_from_convolution=False)
     else:
       return self.conv_op(inp, filter)
@@ -2519,42 +2392,6 @@ def conv2d_transpose_v2(
       name=name)
 
 
-def _conv2d_expanded_batch(
-    input,  # pylint: disable=redefined-builtin
-    filters,
-    strides,
-    padding,
-    data_format,
-    dilations,
-    name):
-  """Helper function for `convolution_internal`; handles expanded batches."""
-  # Try really hard to avoid modifying the legacy name scopes - return early.
-  shape = getattr(input, "shape", None)
-  if shape is not None:
-    ndims = getattr(shape, "ndims", -1)
-    if ndims == -1: ndims = len(shape)
-    if ndims in (4, 3, 2, 1, 0, None):
-      return gen_nn_ops.conv2d(
-          input,
-          filter=filters,
-          strides=strides,
-          padding=padding,
-          data_format=data_format,
-          dilations=dilations,
-          name=name)
-  return _squeeze_batch_dims(
-      input,
-      functools.partial(
-          gen_nn_ops.conv2d,
-          filter=filters,
-          strides=strides,
-          padding=padding,
-          data_format=data_format,
-          dilations=dilations),
-      inner_rank=3,
-      name=name)
-
-
 @tf_export("nn.atrous_conv2d_transpose")
 @dispatch.add_dispatch_support
 def atrous_conv2d_transpose(value,
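
The deleted _conv2d_expanded_batch above dispatched rank <= 4 inputs straight to gen_nn_ops.conv2d and otherwise used _squeeze_batch_dims: flatten all leading batch dims into one, run the rank-4 op, then restore the batch shape. A minimal sketch of that strategy; apply_with_flat_batch is a hypothetical stand-in, not TF's _squeeze_batch_dims:

import tensorflow as tf

def apply_with_flat_batch(x, op, inner_rank=3):
  # Collapse leading batch dims, apply a rank-(inner_rank + 1) op, restore them.
  shape = tf.shape(x)
  batch_shape, inner_shape = shape[:-inner_rank], shape[-inner_rank:]
  flat = tf.reshape(x, tf.concat([[-1], inner_shape], axis=0))
  out = op(flat)
  return tf.reshape(out, tf.concat([batch_shape, tf.shape(out)[1:]], axis=0))

x = tf.random.normal([2, 5, 8, 8, 3])   # batch dims [2, 5], inner HWC [8, 8, 3]
filt = tf.random.normal([1, 1, 3, 4])   # 1x1 conv, 3 in / 4 out channels
y = apply_with_flat_batch(
    x, lambda t: tf.nn.conv2d(t, filt, strides=1, padding="VALID"))
print(y.shape)  # (2, 5, 8, 8, 4)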