diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 7623f72925e..b42aa270aa5 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -228,7 +228,7 @@ class Conv(Layer):
 
     # Apply causal padding to inputs for Conv1D.
     if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
-      inputs = array_ops.pad(inputs, self._compute_causal_padding())
+      inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
 
     outputs = self._convolution_op(inputs, self.kernel)
 
@@ -324,13 +324,17 @@ class Conv(Layer):
     base_config = super(Conv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def _compute_causal_padding(self):
+  def _compute_causal_padding(self, inputs):
    """Calculates padding for 'causal' option for 1-d conv layers."""
     left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
-    if self.data_format == 'channels_last':
-      causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+    if getattr(inputs.shape, 'ndims', None) is None:
+      batch_rank = 1
     else:
-      causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+      batch_rank = len(inputs.shape) - 2
+    if self.data_format == 'channels_last':
+      causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
+    else:
+      causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
     return causal_padding
 
   def _get_channel_axis(self):
@@ -404,6 +408,16 @@ class Conv1D(Conv):
   >>> print(y.shape)
   (4, 8, 32)
 
+  >>> # With extended batch shape [4, 7] (e.g. weather data where batch
+  >>> # dimensions correspond to spatial location and the third dimension
+  >>> # corresponds to time.)
+  >>> input_shape = (4, 7, 10, 128)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv1D(
+  ... 32, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 8, 32)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -451,10 +465,10 @@ class Conv1D(Conv):
       see `keras.constraints`).
 
   Input shape:
-    3D tensor with shape: `(batch_size, steps, input_dim)`
+    3+D tensor with shape: `batch_shape + (steps, input_dim)`
 
   Output shape:
-    3D tensor with shape: `(batch_size, new_steps, filters)`
+    3+D tensor with shape: `batch_shape + (new_steps, filters)`
       `steps` value might have changed due to padding or strides.
 
   Returns:
@@ -462,7 +476,7 @@ class Conv1D(Conv):
     `activation(conv1d(inputs, kernel) + bias)`.
 
   Raises:
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
   """
 
   def __init__(self,
@@ -702,6 +716,15 @@ class Conv3D(Conv):
   >>> print(y.shape)
   (4, 26, 26, 26, 2)
 
+  >>> # With extended batch shape [4, 7], e.g. a batch of 4 videos of 3D frames,
+  >>> # with 7 frames per video.
+  >>> input_shape = (4, 7, 28, 28, 28, 1)
+  >>> x = tf.random.normal(input_shape)
+  >>> y = tf.keras.layers.Conv3D(
+  ... 2, 3, activation='relu', input_shape=input_shape[2:])(x)
+  >>> print(y.shape)
+  (4, 7, 26, 26, 26, 2)
+
   Arguments:
     filters: Integer, the dimensionality of the output space
       (i.e. the number of output filters in the convolution).
@@ -721,9 +744,9 @@ class Conv3D(Conv):
       one of `channels_last` (default) or `channels_first`.
       The ordering of the dimensions in the inputs.
       `channels_last` corresponds to inputs with shape
-      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
+      `batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)`
       while `channels_first` corresponds to inputs with shape
-      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
+      `batch_shape + (channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
       It defaults to the `image_data_format` value found in your
       Keras config file at `~/.keras/keras.json`.
       If you never set it, then it will be "channels_last".
@@ -760,30 +783,30 @@ class Conv3D(Conv):
       see `keras.constraints`).
 
   Input shape:
-    5D tensor with shape:
-    `(batch_size, channels, conv_dim1, conv_dim2, conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (channels, conv_dim1, conv_dim2, conv_dim3)` if
       data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, conv_dim1, conv_dim2, conv_dim3, channels)` if
+    or 5+D tensor with shape:
+    `batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)` if
       data_format='channels_last'.
 
   Output shape:
-    5D tensor with shape:
-    `(batch_size, filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
+    5+D tensor with shape:
+    `batch_shape + (filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if
       data_format='channels_first'
-    or 5D tensor with shape:
-    `(batch_size, new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
+    or 5+D tensor with shape:
+    `batch_shape + (new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if
       data_format='channels_last'.
     `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have
       changed due to padding.
 
   Returns:
-    A tensor of rank 5 representing
+    A tensor of rank 5+ representing
     `activation(conv3d(inputs, kernel) + bias)`.
 
   Raises:
     ValueError: if `padding` is "causal".
-    ValueError: when both `strides` > 1 and `dilation_rate` > 1.
+    ValueError: when both `strides > 1` and `dilation_rate > 1`.
""" def __init__(self, diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index 7d54184aa9c..51d6710de27 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -49,6 +49,21 @@ class Conv1DTest(keras_parameterized.TestCase): input_shape=(num_samples, length, stack_size), expected_output_shape=expected_output_shape) + def _run_test_extra_batch_dim(self, kwargs, expected_output_shape): + batch_shape = (2, 11) + stack_size = 3 + length = 7 + + with self.cached_session(use_gpu=True): + if expected_output_shape is not None: + expected_output_shape = (None,) + expected_output_shape + + testing_utils.layer_test( + keras.layers.Conv1D, + kwargs=kwargs, + input_shape=batch_shape + (length, stack_size), + expected_output_shape=expected_output_shape) + @parameterized.named_parameters( ('padding_valid', { 'padding': 'valid' @@ -85,6 +100,7 @@ class Conv1DTest(keras_parameterized.TestCase): kwargs['kernel_size'] = 3 if not requires_gpu or test.is_gpu_available(cuda_only=True): self._run_test(kwargs, expected_output_shape) + self._run_test_extra_batch_dim(kwargs, expected_output_shape) def test_conv1d_regularizers(self): kwargs = { @@ -281,6 +297,27 @@ class Conv3DTest(keras_parameterized.TestCase): expected_output_shape=expected_output_shape, validate_training=validate_training) + def _run_test_extra_batch_dim(self, + kwargs, + expected_output_shape, + validate_training=True): + batch_shape = (2, 11) + stack_size = 3 + num_row = 7 + num_col = 6 + depth = 5 + + with self.cached_session(use_gpu=True): + if expected_output_shape is not None: + expected_output_shape = (None,) + expected_output_shape + + testing_utils.layer_test( + keras.layers.Conv3D, + kwargs=kwargs, + input_shape=batch_shape + (depth, num_row, num_col, stack_size), + expected_output_shape=expected_output_shape, + validate_training=validate_training) + @parameterized.named_parameters( ('padding_valid', { 'padding': 'valid' @@ -313,6 +350,8 @@ class Conv3DTest(keras_parameterized.TestCase): test_training = 'groups' not in kwargs or not test_util.is_xla_enabled() if not requires_gpu or test.is_gpu_available(cuda_only=True): self._run_test(kwargs, expected_output_shape, test_training) + self._run_test_extra_batch_dim(kwargs, expected_output_shape, + test_training) def test_conv3d_regularizers(self): kwargs = { diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py index 5ac8c11130b..78b184a43ba 100644 --- a/tensorflow/python/kernel_tests/conv1d_test.py +++ b/tensorflow/python/kernel_tests/conv1d_test.py @@ -55,6 +55,32 @@ class Conv1DTest(test.TestCase): self.assertEqual(len(output), 2) self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) + def testExpandedBatch(self): + """Test that argument passing to conv1d is handled properly.""" + # double datatype is currently not supported for convolution ops + # on the ROCm platform + x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32) + x = array_ops.expand_dims(x, 0) # Add batch dimension + x = array_ops.expand_dims(x, 2) # And depth dimension + x = array_ops.stack([x, x]) # Make batch shape [2, 1] + filters = constant_op.constant([2, 1], dtype=dtypes.float32) + filters = array_ops.expand_dims(filters, 1) # in_channels + filters = array_ops.expand_dims(filters, 2) # out_channels + # Filters is 2x1x1 + for stride in [1, 2]: + with self.cached_session(use_gpu=test.is_gpu_available()): + c = 
nn_ops.conv1d(x, filters, stride, padding="VALID") + reduced = array_ops.squeeze(c) # Sequeeze out dims 1 and 3. + output = self.evaluate(reduced) + if stride == 1: + self.assertAllClose(output, + [[2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4], + [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]]) + else: + self.assertAllClose( + output, + [[2 * 1 + 1 * 2, 2 * 3 + 1 * 4], [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]]) + def testConv1DTranspose(self): with self.cached_session(): stride = 2 diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index ce46f1d2782..ff4da3afc9f 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -66,12 +66,8 @@ class Conv3DTest(test.TestCase): def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride, padding, data_format, dtype, use_gpu): - total_size_tensor = 1 - total_size_filter = 1 - for s in tensor_in_sizes: - total_size_tensor *= s - for s in filter_in_sizes: - total_size_filter *= s + total_size_tensor = np.prod(tensor_in_sizes) + total_size_filter = np.prod(filter_in_sizes) # Initializes the input tensor with array containing numbers from 0 to 1. # We keep the input tensor values fairly small to avoid overflowing float16 @@ -126,12 +122,8 @@ class Conv3DTest(test.TestCase): def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes, stride, dilation, padding, data_format, use_gpu): - total_size_tensor = 1 - total_size_filter = 1 - for s in tensor_in_sizes: - total_size_tensor *= s - for s in filter_in_sizes: - total_size_filter *= s + total_size_tensor = np.prod(tensor_in_sizes) + total_size_filter = np.prod(filter_in_sizes) # Initializes the input tensor with array containing incrementing # numbers from 1. 
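The next hunk adds three tests that all assert one invariant: reshaping the leading batch dimension `[10]` into a batch shape `[2, 5]` must leave the conv3d result unchanged up to a reshape. A minimal standalone sketch of that invariant (a hypothetical illustration, not part of the patch; the rank-6 call only succeeds once this change is in place):

```python
import numpy as np
import tensorflow as tf

# Deterministic values 1..N, mirroring _CreateNumpyTensor below.
x1 = np.arange(1, np.prod([10, 2, 3, 1, 3]) + 1,
               dtype=np.float32).reshape([10, 2, 3, 1, 3])
x2 = x1.reshape([2, 5, 2, 3, 1, 3])  # Same data, batch shape [2, 5].
filt = np.arange(1, 10, dtype=np.float32).reshape([1, 1, 1, 3, 3])

y1 = tf.nn.conv3d(x1, filt, strides=[1, 1, 1, 1, 1], padding="VALID")
y2 = tf.nn.conv3d(x2, filt, strides=[1, 1, 1, 1, 1], padding="VALID")

# The expanded-batch result is the ordinary result, reshaped.
np.testing.assert_allclose(y1.numpy(), y2.numpy().reshape(y1.shape))
```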
@@ -196,6 +188,69 @@ class Conv3DTest(test.TestCase):
       self.assertAllClose(
           e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)
 
+  def _CreateNumpyTensor(self, sizes):
+    return np.asarray([f * 1.0
+                       for f in range(1,
+                                      np.prod(sizes) + 1)]).reshape(sizes)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConv3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.conv3d_v2(
+        x1, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    conv2 = nn_ops.conv3d_v2(
+        x2, filter_in, strides=[1, 1, 1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass3DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 1, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3]
+    filter_in_sizes = [1, 1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1, filter_in, strides=[1, 1, 1], padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2, filter_in, strides=[1, 1, 1], padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(conv1, self.evaluate(conv2).reshape(conv1.shape))
+
   def testConv3D1x1x1Filter(self):
     expected_output = [
         0.18518519, 0.22222222, 0.25925926, 0.40740741, 0.5, 0.59259259,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 589f9984be1..c29fcb1f1d6 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -228,7 +228,7 @@ class _NonAtrousConvolution(object):
             % data_format)
       self.strides = strides
       self.data_format = data_format
-      self.conv_op = gen_nn_ops.conv3d
+      self.conv_op = _conv3d_expanded_batch
 
   # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
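Everywhere this patch adds support for extra batch dimensions, the mechanism is the same: `squeeze_batch_dims` (already defined in nn_ops.py) collapses the leading dimensions into a single batch dimension, invokes the wrapped kernel, and restores the original batch shape afterwards. Roughly, for fully static shapes (a simplified sketch for orientation only; the real helper also copes with partially unknown shapes):

```python
import functools
import tensorflow as tf

def squeeze_batch_dims_sketch(inp, op, inner_rank):
  """Static-shape sketch of the flatten/apply/restore pattern."""
  shape = inp.shape.as_list()
  batch_shape, inner_shape = shape[:-inner_rank], shape[-inner_rank:]
  flat = tf.reshape(inp, [-1] + inner_shape)  # e.g. [2, 5, ...] -> [10, ...]
  out = op(flat)                              # the kernel sees one batch dim
  return tf.reshape(out, batch_shape + out.shape.as_list()[1:])

# Example: a rank-6 conv3d routed through the rank-5 kernel, which is what
# _conv3d_expanded_batch (added further down) arranges via functools.partial.
x = tf.random.normal([2, 5, 4, 4, 4, 3])
f = tf.random.normal([1, 1, 1, 3, 2])
y = squeeze_batch_dims_sketch(
    x,
    functools.partial(
        tf.nn.conv3d, filters=f, strides=[1, 1, 1, 1, 1], padding="VALID"),
    inner_rank=4)
print(y.shape)  # (2, 5, 4, 4, 4, 2)
```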
@@ -241,7 +241,6 @@ class _NonAtrousConvolution(object):
           padding=padding,
           data_format=data_format,
           name=name)
-    # pylint: enable=redefined-builtin
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
 
@@ -1104,7 +1103,7 @@ def convolution_internal(
     scope = "convolution"
 
   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}
 
     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1783,40 +1782,42 @@ def conv1d(
     name=None,
     input=None,  # pylint: disable=redefined-builtin
     dilations=None):
-  r"""Computes a 1-D convolution given 3-D input and filter tensors.
+  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.
 
   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.
 
   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
   For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.
 
   Args:
-    value: A 3D `Tensor`.  Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`.  Must have the same type as `value`.
+    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
+      `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `value`.
     stride: An int or list of `ints` that has length `1` or `3`.  The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     use_cudnn_on_gpu: An optional `bool`.  Defaults to `True`.
     data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels].  The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of `batch_shape + [in_width,
+      in_channels]`.  The `"NCW"` format stores data as `batch_shape +
+      [in_channels, in_width]`.
     name: A name for the operation (optional).
     input: Alias for value.
     dilations: An int or list of `ints` that has length `1` or `3` which
@@ -1832,14 +1833,14 @@ def conv1d(
   """
   value = deprecation.deprecated_argument_lookup("input", input, "value", value)
   with ops.name_scope(name, "conv1d", [value, filters]) as name:
-    # Reshape the input tensor to [batch, 1, in_width, in_channels]
+    # Reshape the input tensor to batch_shape + [1, in_width, in_channels]
     if data_format is None or data_format == "NHWC" or data_format == "NWC":
       data_format = "NHWC"
-      spatial_start_dim = 1
+      spatial_start_dim = -3
       channel_index = 2
     elif data_format == "NCHW" or data_format == "NCW":
       data_format = "NCHW"
-      spatial_start_dim = 2
+      spatial_start_dim = -2
       channel_index = 1
     else:
       raise ValueError("data_format must be \"NWC\" or \"NCW\".")
@@ -1848,15 +1849,30 @@ def conv1d(
     value = array_ops.expand_dims(value, spatial_start_dim)
     filters = array_ops.expand_dims(filters, 0)
-    result = gen_nn_ops.conv2d(
-        value,
-        filters,
-        strides,
-        padding,
-        use_cudnn_on_gpu=use_cudnn_on_gpu,
-        data_format=data_format,
-        dilations=dilations,
-        name=name)
+    if value.shape.ndims in (4, 3, 2, 1, 0, None):
+      result = gen_nn_ops.conv2d(
+          value,
+          filters,
+          strides,
+          padding,
+          use_cudnn_on_gpu=use_cudnn_on_gpu,
+          data_format=data_format,
+          dilations=dilations,
+          name=name)
+    else:
+      result = squeeze_batch_dims(
+          value,
+          functools.partial(
+              gen_nn_ops.conv2d,
+              filter=filters,
+              strides=strides,
+              padding=padding,
+              use_cudnn_on_gpu=use_cudnn_on_gpu,
+              data_format=data_format,
+              dilations=dilations,
+          ),
+          inner_rank=3,
+          name=name)
     return array_ops.squeeze(result, [spatial_start_dim])
 
 
@@ -1873,36 +1889,38 @@ def conv1d_v2(
   r"""Computes a 1-D convolution given 3-D input and filter tensors.
 
   Given an input tensor of shape
-    [batch, in_width, in_channels]
-  if data_format is "NWC", or
-    [batch, in_channels, in_width]
-  if data_format is "NCW",
+    `batch_shape + [in_width, in_channels]`
+  if `data_format` is `"NWC"`, or
+    `batch_shape + [in_channels, in_width]`
+  if `data_format` is `"NCW"`,
   and a filter / kernel tensor of shape
-  [filter_width, in_channels, out_channels], this op reshapes
-  the arguments to pass them to conv2d to perform the equivalent
+  `[filter_width, in_channels, out_channels]`, this op reshapes
+  the arguments to pass them to `conv2d` to perform the equivalent
   convolution operation.
 
   Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
-  For example, if `data_format` does not start with "NC", a tensor of shape
-    [batch, in_width, in_channels]
+  For example, if `data_format` does not start with `"NC"`, a tensor of shape
+    `batch_shape + [in_width, in_channels]`
   is reshaped to
-    [batch, 1, in_width, in_channels],
+    `batch_shape + [1, in_width, in_channels]`,
   and the filter is reshaped to
-    [1, filter_width, in_channels, out_channels].
+    `[1, filter_width, in_channels, out_channels]`.
   The result is then reshaped back to
-    [batch, out_width, out_channels]
+    `batch_shape + [out_width, out_channels]`
   \(where out_width is a function of the stride and padding as in conv2d\) and
   returned to the caller.
 
   Args:
-    input: A 3D `Tensor`.  Must be of type `float16`, `float32`, or `float64`.
-    filters: A 3D `Tensor`.  Must have the same type as `input`.
+    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
+      `float64`.
+    filters: A Tensor of rank at least 3. Must have the same type as `input`.
     stride: An int or list of `ints` that has length `1` or `3`.  The number of
       entries by which the filter is moved right at each step.
     padding: 'SAME' or 'VALID'
     data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
-      the data is stored in the order of [batch, in_width, in_channels].  The
-      `"NCW"` format stores data as [batch, in_channels, in_width].
+      the data is stored in the order of
+      `batch_shape + [in_width, in_channels]`.  The `"NCW"` format stores data
+      as `batch_shape + [in_channels, in_width]`.
     dilations: An int or list of `ints` that has length `1` or `3` which
       defaults to 1. The dilation factor for each dimension of input. If set to
       k > 1, there will be k-1 skipped cells between each filter element on that
@@ -2071,9 +2089,9 @@ def conv2d_v2(input,  # pylint: disable=redefined-builtin
   Args:
     input: A `Tensor`. Must be one of the following types:
       `half`, `bfloat16`, `float32`, `float64`.
-      A 4+-D tensor. The dimension order is interpreted according to the value
-      of `data_format`; with the all-but-inner-3 dimensions acting as batch
-      dimensions. See below for details.
+      A Tensor of rank at least 4. The dimension order is interpreted according
+      to the value of `data_format`, with the all-but-inner-3 dimensions acting
+      as batch dimensions. See below for details.
     filters: A `Tensor`. Must have the same type as `input`.
       A 4-D tensor of shape
       `[filter_height, filter_width, in_channels, out_channels]`
@@ -2214,6 +2232,8 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode.  This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filter,
@@ -2535,6 +2555,8 @@ def _conv2d_expanded_batch(
     ndims = getattr(shape, "ndims", -1)
     if ndims == -1: ndims = len(shape)
     if ndims in (4, 3, 2, 1, 0, None):
+      # We avoid calling squeeze_batch_dims to reduce extra python function
+      # call slowdown in eager mode.  This branch doesn't require reshapes.
       return gen_nn_ops.conv2d(
           input,
           filter=filters,
@@ -2931,6 +2953,46 @@ def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builti
       name=name)
 
 
+def _conv3d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filter,  # pylint: disable=redefined-builtin
+    strides,
+    padding,
+    data_format,
+    dilations=None,
+    name=None):
+  """Helper function for `conv3d`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1:
+      ndims = len(shape)
+  if ndims in (5, 4, 3, 2, 1, 0, None):
+    # We avoid calling squeeze_batch_dims to reduce extra python function
+    # call slowdown in eager mode.  This branch doesn't require reshapes.
+    return gen_nn_ops.conv3d(
+        input,
+        filter,
+        strides,
+        padding,
+        data_format=data_format,
+        dilations=dilations,
+        name=name)
+  else:
+    return squeeze_batch_dims(
+        input,
+        functools.partial(
+            gen_nn_ops.conv3d,
+            filter=filter,
+            strides=strides,
+            padding=padding,
+            data_format=data_format,
+            dilations=dilations),
+        inner_rank=4,
+        name=name)
+
+
 @tf_export("nn.conv3d", v1=[])
 @dispatch.add_dispatch_support
 def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
@@ -2942,13 +3004,8 @@ def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
               name=None):
   if dilations is None:
     dilations = [1, 1, 1, 1, 1]
-  return gen_nn_ops.conv3d(input,
-                           filters,
-                           strides,
-                           padding,
-                           data_format=data_format,
-                           dilations=dilations,
-                           name=name)
+  return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
+                                dilations, name)
 
 
 @tf_export(v1=["nn.conv3d"])
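Taken together, these changes let both the Keras layers and the low-level ops accept extra leading batch dimensions. A quick end-to-end check of the expected behavior after the patch (a sketch mirroring the new docstring examples and tests, not code from the patch itself):

```python
import tensorflow as tf

# Keras Conv1D with causal padding on a rank-4 input; the patched
# _compute_causal_padding now pads [[0, 0], [0, 0], [left_pad, 0], [0, 0]].
x = tf.random.normal([4, 7, 10, 128])            # batch shape [4, 7]
layer = tf.keras.layers.Conv1D(32, 3, padding='causal', activation='relu')
print(layer(x).shape)                            # (4, 7, 10, 32)

# tf.nn.conv3d on a rank-6 input; dispatched through _conv3d_expanded_batch.
v = tf.random.normal([2, 5, 8, 8, 8, 3])
f = tf.random.normal([3, 3, 3, 3, 4])
y = tf.nn.conv3d(v, f, strides=[1, 1, 1, 1, 1], padding='SAME')
print(y.shape)                                   # (2, 5, 8, 8, 8, 4)
```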