diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 18b7a47fc8c..e01abc8133d 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -455,6 +455,58 @@ class Conv2DTest(test.TestCase):
         conv1,
         self.evaluate(conv2).reshape(conv1.shape))

+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass2DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
   @test_util.run_in_graph_and_eager_modes
   def testConv2D2x2Filter2x1Dilation(self):
     self._VerifyDilatedConvValues(
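For reference, here is the behavior the two new tests pin down, written as a standalone eager-mode sketch against the public `tf.nn.convolution` API. This is illustrative only and assumes a TensorFlow build that already includes this patch:

import numpy as np
import tensorflow as tf

filters = np.ones([1, 1, 3, 3], dtype=np.float32)     # HWIO filter, 1x1, 3 -> 3 channels
x1 = np.random.rand(10, 2, 3, 3).astype(np.float32)   # [batch, H, W, C]
x2 = x1.reshape([2, 5, 2, 3, 3])                      # [b0, b1, H, W, C]: two batch dims

y1 = tf.nn.convolution(x1, filters, strides=[1, 1], padding="VALID")
y2 = tf.nn.convolution(x2, filters, strides=[1, 1], padding="VALID")

assert y1.shape == (10, 2, 3, 3)
assert y2.shape == (2, 5, 2, 3, 3)                    # leading batch dims are preserved
np.testing.assert_allclose(y1.numpy(), y2.numpy().reshape(y1.shape), rtol=1e-5)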
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 4c00d085f82..24ee94fac48 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -131,9 +131,9 @@ def _non_atrous_convolution(
   """
   with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
     input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-    input_shape = input.get_shape()
+    input_shape = input.shape
     filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
-    filter_shape = filter.get_shape()
+    filter_shape = filter.shape
     op = _NonAtrousConvolution(
         input_shape,
         filter_shape=filter_shape,
@@ -148,36 +148,51 @@ class _NonAtrousConvolution(object):
   """Helper class for _non_atrous_convolution.

   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
+  `__call__` are compatible with `input_shape` and `filter_shape` passed to the
   constructor.

   Arguments:
-    input_shape: static input shape, i.e. input.get_shape().
-    filter_shape: static filter shape, i.e. filter.get_shape().
+    input_shape: static input shape, i.e. input.shape.
+    filter_shape: static filter shape, i.e. filter.shape.
     padding: see _non_atrous_convolution.
     data_format: see _non_atrous_convolution.
     strides: see _non_atrous_convolution.
     name: see _non_atrous_convolution.
+    num_batch_dims: (Optional.) The number of batch dimensions in the input;
+      if not provided, the default of `1` is used.
   """

   def __init__(
       self,
       input_shape,
-      filter_shape,  # pylint: disable=redefined-builtin
+      filter_shape,
       padding,
       data_format=None,
       strides=None,
-      name=None):
-    filter_shape = filter_shape.with_rank(input_shape.ndims)
+      name=None,
+      num_batch_dims=1):
+    # The filter shape is always rank num_spatial_dims + 2,
+    # where num_spatial_dims == input_shape.ndims - num_batch_dims - 1.
+    if input_shape.ndims is not None:
+      filter_shape = filter_shape.with_rank(
+          input_shape.ndims - num_batch_dims + 1)
     self.padding = padding
     self.name = name
-    input_shape = input_shape.with_rank(filter_shape.ndims)
+    # The input shape has rank num_spatial_dims + num_batch_dims + 1,
+    # and filter_shape is always rank num_spatial_dims + 2.
+    if filter_shape.ndims is not None:
+      input_shape = input_shape.with_rank(
+          filter_shape.ndims + num_batch_dims - 1)
     if input_shape.ndims is None:
-      raise ValueError("Rank of convolution must be known")
-    if input_shape.ndims < 3 or input_shape.ndims > 5:
       raise ValueError(
-          "`input` and `filter` must have rank at least 3 and at most 5")
-    conv_dims = input_shape.ndims - 2
+          "Rank of convolution must be known, but saw input_shape.ndims == {}"
+          .format(input_shape.ndims))
+    if not 3 <= input_shape.ndims - num_batch_dims + 1 <= 5:
+      raise ValueError(
+          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
+          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
+          .format(input_shape.ndims, num_batch_dims))
+    conv_dims = input_shape.ndims - num_batch_dims - 1
     if strides is None:
       strides = [1] * conv_dims
     elif len(strides) != conv_dims:
@@ -520,7 +535,7 @@ def with_space_to_batch(
   """
   input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-  input_shape = input.get_shape()
+  input_shape = input.shape

   def build_op(num_spatial_dims, padding):
     return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -540,18 +555,19 @@ class _WithSpaceToBatch(object):
   """Helper class for with_space_to_batch.

   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `spatial_dims` passed to the constructor.

   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    dilation_rate: see with_space_to_batch
-    padding: see with_space_to_batch
+    input_shape: static shape of input, i.e. input.shape.
+    dilation_rate: see `with_space_to_batch`.
+    padding: see `with_space_to_batch`.
     build_op: Function that maps (num_spatial_dims, paddings) -> (function that
       maps (input, filter) -> output).
-    filter_shape: see with_space_to_batch
-    spatial_dims: see with_space_to_batch
-    data_format: see with_space_to_batch
+    filter_shape: see `with_space_to_batch`.
+    spatial_dims: see `with_space_to_batch`.
+    data_format: see `with_space_to_batch`.
+    num_batch_dims: (Optional.) Number of batch dimensions in `input_shape`.
   """

   def __init__(self,
@@ -561,24 +577,25 @@ class _WithSpaceToBatch(object):
                build_op,
                filter_shape=None,
                spatial_dims=None,
-               data_format=None):
+               data_format=None,
+               num_batch_dims=1):
     """Helper class for _with_space_to_batch."""
     dilation_rate = ops.convert_to_tensor(
         dilation_rate, dtypes.int32, name="dilation_rate")
-    try:
-      rate_shape = dilation_rate.get_shape().with_rank(1)
-    except ValueError:
-      raise ValueError("rate must be rank 1")
+    if dilation_rate.shape.ndims not in (None, 1):
+      raise ValueError(
+          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))

-    if not dilation_rate.get_shape().is_fully_defined():
-      raise ValueError("rate must have known shape")
+    if not dilation_rate.shape.is_fully_defined():
+      raise ValueError("rate must have known shape, but saw {}"
+                       .format(dilation_rate.shape))

-    num_spatial_dims = rate_shape.dims[0].value
+    num_spatial_dims = dilation_rate.shape.dims[0].value

     if data_format is not None and data_format.startswith("NC"):
-      starting_spatial_dim = 2
+      starting_spatial_dim = num_batch_dims + 1
     else:
-      starting_spatial_dim = 1
+      starting_spatial_dim = num_batch_dims

     if spatial_dims is None:
       spatial_dims = range(starting_spatial_dim,
@@ -588,7 +605,7 @@ class _WithSpaceToBatch(object):
     if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
       raise ValueError(
           "spatial_dims must be a monotonically increasing sequence of "
-          "positive integers")
+          "positive integers, but saw: {}".format(orig_spatial_dims))

     if data_format is not None and data_format.startswith("NC"):
       expected_input_rank = spatial_dims[-1]
@@ -599,14 +616,16 @@ class _WithSpaceToBatch(object):
       input_shape.with_rank_at_least(expected_input_rank)
     except ValueError:
       raise ValueError(
-          "input tensor must have rank %d at least" % (expected_input_rank))
+          "input tensor must have rank at least {}, but saw rank {}"
+          .format(expected_input_rank, input_shape.ndims))

     const_rate = tensor_util.constant_value(dilation_rate)
     rate_or_const_rate = dilation_rate
     if const_rate is not None:
       rate_or_const_rate = const_rate
       if np.any(const_rate < 1):
-        raise ValueError("dilation_rate must be positive")
+        raise ValueError("dilation_rate must be positive, but saw: {}"
+                         .format(const_rate))
       if np.all(const_rate == 1):
         self.call = build_op(num_spatial_dims, padding)
         return
@@ -672,6 +691,7 @@ class _WithSpaceToBatch(object):
       filter_shape = array_ops.shape(filter)
       base_paddings = _with_space_to_batch_base_paddings(
           filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
+
     paddings, crops = array_ops.required_space_to_batch_paddings(
         input_shape=input_spatial_shape,
         base_paddings=base_paddings,
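The index arithmetic above generalizes the previously hard-coded starting dimensions (1, or 2 for "NC*" layouts). A minimal standalone sketch of that arithmetic; the function name is illustrative and not part of the patch:

def spatial_dims_for(num_spatial_dims, num_batch_dims=1, data_format=None):
  # Channels-first ("NC...") layouts put channels right after the batch
  # dimensions, so the spatial dimensions start one slot later.
  if data_format is not None and data_format.startswith("NC"):
    start = num_batch_dims + 1
  else:
    start = num_batch_dims
  return list(range(start, start + num_spatial_dims))

assert spatial_dims_for(2) == [1, 2]                         # NHWC, one batch dim
assert spatial_dims_for(2, num_batch_dims=2) == [2, 3]       # [b0, b1, H, W, C]
assert spatial_dims_for(2, 2, data_format="NCHW") == [3, 4]  # [b0, b1, C, H, W]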
""" def __init__(self, @@ -561,24 +577,25 @@ class _WithSpaceToBatch(object): build_op, filter_shape=None, spatial_dims=None, - data_format=None): + data_format=None, + num_batch_dims=1): """Helper class for _with_space_to_batch.""" dilation_rate = ops.convert_to_tensor( dilation_rate, dtypes.int32, name="dilation_rate") - try: - rate_shape = dilation_rate.get_shape().with_rank(1) - except ValueError: - raise ValueError("rate must be rank 1") + if dilation_rate.shape.ndims not in (None, 1): + raise ValueError( + "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims)) - if not dilation_rate.get_shape().is_fully_defined(): - raise ValueError("rate must have known shape") + if not dilation_rate.shape.is_fully_defined(): + raise ValueError("rate must have known shape, but saw {}" + .format(dilation_rate.shape)) - num_spatial_dims = rate_shape.dims[0].value + num_spatial_dims = dilation_rate.shape.dims[0].value if data_format is not None and data_format.startswith("NC"): - starting_spatial_dim = 2 + starting_spatial_dim = num_batch_dims + 1 else: - starting_spatial_dim = 1 + starting_spatial_dim = num_batch_dims if spatial_dims is None: spatial_dims = range(starting_spatial_dim, @@ -588,7 +605,7 @@ class _WithSpaceToBatch(object): if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims): raise ValueError( "spatial_dims must be a monotonically increasing sequence of " - "positive integers") + "positive integers, but saw: {}".format(orig_spatial_dims)) if data_format is not None and data_format.startswith("NC"): expected_input_rank = spatial_dims[-1] @@ -599,14 +616,16 @@ class _WithSpaceToBatch(object): input_shape.with_rank_at_least(expected_input_rank) except ValueError: raise ValueError( - "input tensor must have rank %d at least" % (expected_input_rank)) + "input tensor must have rank at least {}, but saw rank {}" + .format(expected_input_rank, input_shape.ndims)) const_rate = tensor_util.constant_value(dilation_rate) rate_or_const_rate = dilation_rate if const_rate is not None: rate_or_const_rate = const_rate if np.any(const_rate < 1): - raise ValueError("dilation_rate must be positive") + raise ValueError("dilation_rate must be positive, but saw: {}" + .format(const_rate)) if np.all(const_rate == 1): self.call = build_op(num_spatial_dims, padding) return @@ -672,6 +691,7 @@ class _WithSpaceToBatch(object): filter_shape = array_ops.shape(filter) base_paddings = _with_space_to_batch_base_paddings( filter_shape, self.num_spatial_dims, self.rate_or_const_rate) + paddings, crops = array_ops.required_space_to_batch_paddings( input_shape=input_spatial_shape, base_paddings=base_paddings, @@ -994,31 +1014,84 @@ def convolution_internal( data_format=None, dilations=None, name=None, - call_from_convolution=True): - """Internal function which performs rank agnostic convolution.""" - if isinstance(input.shape, tensor_shape.TensorShape) and \ - input.shape.rank is not None: - n = len(input.shape) - 2 - elif not isinstance(input.shape, tensor_shape.TensorShape) and \ - input.shape is not None: - n = len(input.shape) - 2 - elif isinstance(filters.shape, tensor_shape.TensorShape) and \ - filters.shape.rank is not None: + call_from_convolution=True, + num_spatial_dims=None): + """Internal function which performs rank agnostic convolution. + + Args: + input: See `convolution`. + filters: See `convolution`. + strides: See `convolution`. + padding: See `convolution`. + data_format: See `convolution`. + dilations: See `convolution`. + name: See `convolution`. 
+  n = None
+  if getattr(filters, "shape", None) is None:
+    with ops.name_scope(name, "convolution_internal", [filters, input]):
+      filters = ops.convert_to_tensor(filters, name="filters")
+  if (isinstance(filters.shape, tensor_shape.TensorShape)
+      and filters.shape.rank is not None):
     n = len(filters.shape) - 2
-  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
-      filters.shape is not None:
+  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
+        and filters.shape is not None):
     n = len(filters.shape) - 2
+
+  if (isinstance(input.shape, tensor_shape.TensorShape)
+      and input.shape.rank is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  elif (not isinstance(input.shape, tensor_shape.TensorShape)
+        and input.shape is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
   else:
+    num_batch_dims = 1  # Default behavior if it cannot be estimated.
+
+  if n is None:
     raise ValueError("rank of input or filter must be known")
+
+  if num_spatial_dims is not None and n != num_spatial_dims:
+    raise ValueError(
+        "inconsistent estimate of spatial dims ({}) vs. actual passed "
+        "num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
+        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
+
   if not 1 <= n <= 3:
     raise ValueError(
-        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
+        "of 1, 2 or 3 but saw {}. num_batch_dims: {}."
+        .format(n, num_batch_dims))

   if data_format is None:
-    channel_index = n + 1
+    channel_index = num_batch_dims + n
   else:
-    channel_index = 1 if data_format.startswith("NC") else n + 1
+    channel_index = (
+        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)

   strides = _get_sequence(strides, n, channel_index, "strides")
   dilations = _get_sequence(dilations, n, channel_index, "dilations")
@@ -1031,7 +1104,7 @@ def convolution_internal(
     scope = "convolution"

   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}

     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1061,7 +1134,8 @@ def convolution_internal(
         strides=strides,
         dilation_rate=dilations,
         name=name,
-        data_format=data_format)
+        data_format=data_format,
+        num_spatial_dims=n)
     return op(input, filters)
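The rank bookkeeping in `convolution_internal` reduces to two rules: the number of spatial dimensions `n` is read off the filter's static rank when available, else taken from the `num_spatial_dims` hint, else assumed to leave a single batch dimension; then `num_batch_dims = input_rank - n - 1`. A minimal sketch of those rules; the helper name is illustrative and not part of the patch:

def infer_dims(input_rank, filter_rank=None, num_spatial_dims=None):
  # Mirrors convolution_internal: prefer the filter's static rank, then the
  # explicit num_spatial_dims hint, then assume exactly one batch dimension.
  if filter_rank is not None:
    n = filter_rank - 2
  elif num_spatial_dims is not None:
    n = num_spatial_dims
  else:
    n = input_rank - 2            # Backwards-compatible: one batch dim.
  return n, input_rank - n - 1    # (num_spatial_dims, num_batch_dims)

assert infer_dims(input_rank=4, filter_rank=4) == (2, 1)  # [N, H, W, C]
assert infer_dims(input_rank=5, filter_rank=4) == (2, 2)  # [b0, b1, H, W, C]
assert infer_dims(input_rank=5) == (3, 1)                 # treated as a 3-D conv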
""" def __init__(self, @@ -1089,40 +1180,72 @@ class Convolution(object): strides=None, dilation_rate=None, name=None, - data_format=None): + data_format=None, + num_spatial_dims=None): """Helper function for convolution.""" - num_total_dims = filter_shape.ndims - if num_total_dims is None: - num_total_dims = input_shape.ndims - if num_total_dims is None: - raise ValueError("rank of input or filter must be known") + num_batch_dims = None + filter_shape = tensor_shape.as_shape(filter_shape) + input_shape = tensor_shape.as_shape(input_shape) - num_spatial_dims = num_total_dims - 2 + if filter_shape.ndims is not None: + if (num_spatial_dims is not None and + filter_shape.ndims != num_spatial_dims + 2): + raise ValueError( + "Expected filter_shape.ndims == num_spatial_dims + 2, " + "but saw filter_shape.ndims == {} and num_spatial_dims == {}" + .format(filter_shape.ndims, num_spatial_dims)) + else: + num_spatial_dims = filter_shape.ndims - 2 - try: - input_shape.with_rank(num_spatial_dims + 2) - except ValueError: + if input_shape.ndims is not None and num_spatial_dims is not None: + num_batch_dims = input_shape.ndims - num_spatial_dims - 1 + + if num_spatial_dims is None: + num_spatial_dims = input_shape.ndims - 2 + else: + if input_shape.ndims is not None: + if input_shape.ndims < num_spatial_dims + 2: + raise ValueError( + "Expected input_shape.ndims >= num_spatial_dims + 2, but saw " + "input_shape.ndims == {} and num_spatial_dims == {}" + .format(input_shape.ndims, num_spatial_dims)) + else: + if num_batch_dims is None: + num_batch_dims = input_shape.ndims - num_spatial_dims - 1 + + if num_spatial_dims is None: raise ValueError( - "input tensor must have rank %d" % (num_spatial_dims + 2)) + "Cannot estimate num_spatial_dims since input_shape.ndims is None, " + "filter_shape.ndims is None, and argument num_spatial_dims is also " + "None.") - try: - filter_shape.with_rank(num_spatial_dims + 2) - except ValueError: + if num_batch_dims is None: + num_batch_dims = 1 + + if num_batch_dims < 1: raise ValueError( - "filter tensor must have rank %d" % (num_spatial_dims + 2)) + "num_batch_dims should be >= 1, but saw {}. num_batch_dims was " + "estimated as `input_shape.ndims - num_spatial_dims - 1` and " + "num_spatial_dims was either provided or estimated as " + "`filter_shape.ndims - 2`. 
@@ -1136,6 +1259,8 @@ class Convolution(object):
     self.padding = padding
     self.name = name
     self.dilation_rate = dilation_rate
+    self.num_batch_dims = num_batch_dims
+    self.num_spatial_dims = num_spatial_dims
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1143,7 +1268,8 @@ class Convolution(object):
         build_op=self._build_op,
         filter_shape=filter_shape,
         spatial_dims=spatial_dims,
-        data_format=data_format)
+        data_format=data_format,
+        num_batch_dims=num_batch_dims)

   def _build_op(self, _, padding):
     return _NonAtrousConvolution(
@@ -1152,7 +1278,8 @@ class Convolution(object):
         padding=padding,
         data_format=self.data_format,
         strides=self.strides,
-        name=self.name)
+        name=self.name,
+        num_batch_dims=self.num_batch_dims)

   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
     # TPU convolution supports dilations greater than 1.
@@ -1165,7 +1292,8 @@ class Convolution(object):
           data_format=self.data_format,
           dilations=self.dilation_rate,
           name=self.name,
-          call_from_convolution=False)
+          call_from_convolution=False,
+          num_spatial_dims=self.num_spatial_dims)
     else:
       return self.conv_op(inp, filter)

@@ -2392,6 +2520,42 @@ def conv2d_transpose_v2(
       name=name)


+def _conv2d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filters,
+    strides,
+    padding,
+    data_format,
+    dilations,
+    name):
+  """Helper function for `convolution_internal`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1:
+      ndims = len(shape)
+    if ndims in (4, 3, 2, 1, 0, None):
+      # The input has at most rank 4 (or unknown rank), so there are no extra
+      # batch dimensions to squeeze; call conv2d directly.
+      return gen_nn_ops.conv2d(
+          input,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+  return _squeeze_batch_dims(
+      input,
+      functools.partial(
+          gen_nn_ops.conv2d,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations),
+      inner_rank=3,
+      name=name)
+
+
 @tf_export("nn.atrous_conv2d_transpose")
 @dispatch.add_dispatch_support
 def atrous_conv2d_transpose(value,
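`_squeeze_batch_dims` is referenced above but its definition is not part of this diff. For the 2-D case its effect is equivalent to the following eager-mode sketch, which collapses the leading batch dimensions into one, runs the rank-4 op, and restores them afterwards. The helper name is illustrative, and this is an equivalent computation rather than the actual implementation:

import tensorflow as tf

def conv2d_with_expanded_batch(x, filters, strides, padding):
  # Collapse all leading batch dims into a single batch dim, convolve, then
  # restore the original batch dims on the output.
  inner_rank = 3                                   # H, W, C for an NHWC conv2d
  batch_shape = tf.shape(x)[:-inner_rank]
  inner_shape = tf.shape(x)[-inner_rank:]
  flat = tf.reshape(x, tf.concat([[-1], inner_shape], axis=0))
  out = tf.nn.conv2d(flat, filters, strides=strides, padding=padding)
  out_inner = tf.shape(out)[1:]
  return tf.reshape(out, tf.concat([batch_shape, out_inner], axis=0))

x = tf.ones([2, 5, 2, 3, 3])                       # two batch dims: 2 * 5 == 10
f = tf.ones([1, 1, 3, 3])
y = conv2d_with_expanded_batch(x, f, strides=[1, 1, 1, 1], padding="VALID")
assert y.shape == (2, 5, 2, 3, 3)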