diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 1830fe28674..b7c8395790a 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -1432,6 +1432,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
 
     self._run(fn, 10000)
 
+  def benchmark_tf_nn_convolution_overhead(self):
+    inputs = array_ops.ones((1, 1, 1, 1))
+    filters = array_ops.ones((1, 1, 1, 1))
+
+    def fn():
+      nn_ops.convolution_v2(inputs, filters)
+
+    self._run(fn, 10000)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index bdc7a1677cd..51f4e3b320a 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -224,8 +224,8 @@ class Conv(Layer):
       tf_padding = self.padding.upper()
     else:
       tf_padding = self.padding
-    tf_dilations = self.dilation_rate
-    tf_strides = self.strides
+    tf_dilations = list(self.dilation_rate)
+    tf_strides = list(self.strides)
 
     tf_op_name = self.__class__.__name__
     if tf_op_name == 'Conv1D':
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a61ae753121..1318f575737 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables as variables_lib
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_nn_ops import *
@@ -58,25 +59,46 @@ local_response_normalization = gen_nn_ops.lrn
 
 # pylint: disable=protected-access
 
+# Acceptable channels last formats (robust to H, W, D order).
+_CHANNELS_LAST_FORMATS = frozenset({
+    "NWC", "NHC", "NHWC", "NWHC", "NDHWC", "NDWHC", "NHDWC", "NHWDC", "NWDHC",
+    "NWHDC"
+})
+
 
 def _get_sequence(value, n, channel_index, name):
   """Formats a value input for gen_nn_ops."""
+  # Performance is fast-pathed for common cases:
+  # `None`, `list`, `tuple` and `int`.
   if value is None:
-    value = [1]
+    return [1] * (n + 2)
+
+  # Always convert `value` to a `list`.
+  if isinstance(value, list):
+    pass
+  elif isinstance(value, tuple):
+    value = list(value)
+  elif isinstance(value, int):
+    value = [value]
   elif not isinstance(value, collections_abc.Sized):
     value = [value]
-
-  current_n = len(value)
-  if current_n == n + 2:
-    return value
-  elif current_n == 1:
-    value = list((value[0],) * n)
-  elif current_n == n:
-    value = list(value)
   else:
-    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
-        name, n, n + 2, current_n))
+    value = list(value)  # Try casting to a list.
+
+  len_value = len(value)
+
+  # Fully specified, including batch and channel dims.
+  if len_value == n + 2:
+    return value
+
+  # Apply value to spatial dims only.
+  if len_value == 1:
+    value = value * n  # Broadcast to spatial dimensions.
+  elif len_value != n:
+    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
+        name, n, n + 2, len_value))
+
+  # Add batch and channel dims (always 1).
   if channel_index == 1:
     return [1, 1] + value
   else:
@@ -1042,71 +1064,80 @@ def convolution_internal(
       `num_spatial_dims` is provided and incompatible with the value
       estimated from `filters.shape`.
   """
-  n = None
-  if getattr(filters, 'shape', None) is None:
-    with ops.name_scope(name, 'convolution_internal', [filters, input]):
+  if (not isinstance(filters, variables_lib.Variable) and
+      not tensor_util.is_tensor(filters)):
+    with ops.name_scope("convolution_internal", None, [filters, input]):
       filters = ops.convert_to_tensor(filters, name='filters')
-  if (isinstance(filters.shape, tensor_shape.TensorShape)
-      and filters.shape.rank is not None):
-    n = len(filters.shape) - 2
-  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
-        and filters.shape is not None):
-    n = len(filters.shape) - 2
+  if (not isinstance(input, ops.Tensor) and not tensor_util.is_tensor(input)):
+    with ops.name_scope("convolution_internal", None, [filters, input]):
+      input = ops.convert_to_tensor(input, name="input")
 
-  if (isinstance(input.shape, tensor_shape.TensorShape)
-      and input.shape.rank is not None):
-    if n is None:
-      n = (num_spatial_dims if num_spatial_dims is not None
-           else len(input.shape) - 2)
-    num_batch_dims = len(input.shape) - n - 1
-  elif (not isinstance(input.shape, tensor_shape.TensorShape)
-        and input.shape is not None):
-    if n is None:
-      n = (num_spatial_dims if num_spatial_dims is not None
-           else len(input.shape) - 2)
-    num_batch_dims = len(input.shape) - n - 1
-  else:
-    num_batch_dims = 1  # Default behavior if it cannot be estimated.
-
-  if n is None:
-    raise ValueError("rank of input or filter must be known")
-
-  if num_spatial_dims is not None and n != num_spatial_dims:
+  filters_rank = filters.shape.rank
+  inputs_rank = input.shape.rank
+  if num_spatial_dims is None:
+    if filters_rank:
+      num_spatial_dims = filters_rank - 2
+    elif inputs_rank:
+      num_spatial_dims = inputs_rank - 2
+    else:
+      raise ValueError("rank of input or filter must be known")
+  elif filters_rank and filters_rank - 2 != num_spatial_dims:
     raise ValueError(
         "inconsistent estimate of spatial dims ({}) vs. actual passed "
        "num_spatial_dims ({}).  n was estimated as len(filters.shape) - 2, "
-        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
+        "but filters shape is: {}".format(filters_rank, num_spatial_dims,
+                                          filters.shape))
 
-  if not 1 <= n <= 3:
+  if inputs_rank:
+    num_batch_dims = inputs_rank - num_spatial_dims - 1  # Channel dimension.
+  else:
+    num_batch_dims = 1  # By default, assume single batch dimension.
+
+  if num_spatial_dims not in {1, 2, 3}:
     raise ValueError(
         "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
-        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}."
-        .format(n, num_batch_dims))
+        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}.".format(
+            num_spatial_dims, num_batch_dims))
 
-  if data_format is None:
-    channel_index = num_batch_dims + n
+  if data_format is None or data_format in _CHANNELS_LAST_FORMATS:
+    channel_index = num_batch_dims + num_spatial_dims
   else:
-    channel_index = (
-        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
+    channel_index = num_batch_dims
 
-  strides = _get_sequence(strides, n, channel_index, "strides")
-  dilations = _get_sequence(dilations, n, channel_index, "dilations")
-
-  scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
-  if not call_from_convolution and device_context.enclosing_tpu_context(
-  ) is not None:
-    scope = scopes[n]
+  if dilations is None:
+    dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
+                              "dilations")
+    is_dilated_conv = False
   else:
-    scope = "convolution"
+    dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
+                              "dilations")
+    is_dilated_conv = any(i != 1 for i in dilations)
 
-  with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}
+  strides = _get_sequence(strides, num_spatial_dims, channel_index, "strides")
+  has_tpu_context = device_context.enclosing_tpu_context() is not None
+
+  if name:
+    default_name = None
+  elif not has_tpu_context or call_from_convolution:
+    default_name = "convolution"
+  elif num_spatial_dims == 2:  # Most common case.
+    default_name = "Conv2D"
+  elif num_spatial_dims == 3:
+    default_name = "Conv3D"
+  else:
+    default_name = "conv1d"
 
-    if device_context.enclosing_tpu_context() is not None or all(
-        i == 1 for i in dilations):
-      # fast path for TPU or if no dilation as gradient only supported on GPU
-      # for dilations
-      op = conv_ops[n]
+  with ops.name_scope(name, default_name, [input, filters]) as name:
+    # Fast path for TPU or if no dilation, as gradient only supported on TPU
+    # for dilations.
+    if not is_dilated_conv or has_tpu_context:
+      if num_spatial_dims == 2:  # Most common case.
+        op = _conv2d_expanded_batch
+      elif num_spatial_dims == 3:
+        op = _conv3d_expanded_batch
+      else:
+        op = conv1d
 
       return op(
           input,
           filters,
@@ -1131,7 +1162,7 @@ def convolution_internal(
         dilation_rate=dilations,
         name=name,
         data_format=data_format,
-        num_spatial_dims=n)
+        num_spatial_dims=num_spatial_dims)
     return op(input, filters)
 
 
@@ -2547,11 +2578,8 @@ def _conv2d_expanded_batch(
     name):
   """Helper function for `convolution_internal`; handles expanded batches."""
   # Try really hard to avoid modifying the legacy name scopes - return early.
-  shape = getattr(input, "shape", None)
-  if shape is not None:
-    ndims = getattr(shape, "ndims", -1)
-    if ndims == -1: ndims = len(shape)
-  if ndims in (4, 3, 2, 1, 0, None):
+  input_rank = input.shape.rank
+  if input_rank is None or input_rank < 5:
     # We avoid calling squeeze_batch_dims to reduce extra python function
     # call slowdown in eager mode. This branch doesn't require reshapes.
     return gen_nn_ops.conv2d(
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index e672018bcf6..6ecd1e015d2 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1502,11 +1502,8 @@ class MaxPoolTest(test_lib.TestCase):
 class ConvolutionTest(test_lib.TestCase):
 
   def testUnknownSize(self):
-    # explicitly use float32 for ROCm, as MIOpen does not yet support float64
-    # np.ones defaults to using float64 when dtype is not explicitly specified
-    dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64
     x = tensor_spec.TensorSpec(None, dtypes.float32, name="x")
-    k = np.ones([3, 6, 6, 5], dtype=dtype)
+    k = np.ones([3, 6, 6, 5], dtype=np.float32)
 
     @def_function.function
     def F(value):
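
Illustration (not part of the patch): the rewritten `_get_sequence` is the heart of the overhead reduction; it normalizes a strides/dilations argument to a full-length list, fast-pathing `None`, `int`, `list` and `tuple`. A minimal standalone sketch of the rules it implements, with `normalize_sequence` as a hypothetical stand-in name and the `collections_abc.Sized` fallback branches omitted for brevity:

    def normalize_sequence(value, n, channel_index, name):
      """Expands strides/dilations to a length n + 2 list (batch/channel = 1)."""
      if value is None:
        return [1] * (n + 2)  # Fast path: all-ones strides/dilations.
      value = [value] if isinstance(value, int) else list(value)
      if len(value) == n + 2:
        return value  # Already includes the batch and channel dims.
      if len(value) == 1:
        value = value * n  # Broadcast a scalar to every spatial dimension.
      elif len(value) != n:
        raise ValueError("{} should be of length 1, {} or {} but was {}".format(
            name, n, n + 2, len(value)))
      if channel_index == 1:  # Channels-first, e.g. NCHW.
        return [1, 1] + value
      return [1] + value + [1]  # Channels-last, e.g. NHWC.

    assert normalize_sequence(2, 2, 3, "strides") == [1, 2, 2, 1]
    assert normalize_sequence((2, 3), 2, 1, "strides") == [1, 1, 2, 3]
    assert normalize_sequence(None, 2, 3, "dilations") == [1, 1, 1, 1]

The early returns for `None` and fully specified lists are the fast paths called out in the patch comments.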
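
The public contract is unchanged: `tf.nn.convolution` still accepts strides/dilations as a scalar, a spatial-only sequence, or a fully specified sequence, and the three spellings normalize to the same op attributes. A usage sketch under that assumption (TF 2.x eager mode):

    import tensorflow as tf

    x = tf.ones((1, 8, 8, 2))  # NHWC input.
    k = tf.ones((3, 3, 2, 4))  # HWIO filters.

    # Each spelling normalizes to strides [1, 2, 2, 1] internally.
    y1 = tf.nn.convolution(x, k, strides=2, padding="SAME")
    y2 = tf.nn.convolution(x, k, strides=[2, 2], padding="SAME")
    y3 = tf.nn.convolution(x, k, strides=[1, 2, 2, 1], padding="SAME")
    assert y1.shape == y2.shape == y3.shape == (1, 4, 4, 4)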