Reduce Conv2D.__call__ eager overhead by ~20%, Conv3D.__call__ by ~35%, and Conv1D.__call__ by ~10%

Removes 12us of overhead from tf.nn.convolution

Changes to reduce the overhead of tf.nn.convolution:
- Better ordering of isinstance checks.
- Caching of expensive attributes such as shape, rank, and the enclosing TPU context.
- Various smaller changes for faster conditional logic.
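
Most of these wins follow the same pattern: read costly Python attributes once, branch on the cached value, and test the most common types first. A minimal sketch of the caching idea, loosely modeled on the rank handling in the convolution_internal hunk below (illustrative only, not the committed code; the helper name is made up):

    import tensorflow as tf

    def infer_spatial_dims(inputs, filters, num_spatial_dims=None):
      # Hypothetical helper: read `.shape.rank` once per tensor and reuse it,
      # instead of repeatedly calling len(x.shape) on the eager hot path.
      filters_rank = filters.shape.rank
      inputs_rank = inputs.shape.rank
      if num_spatial_dims is None:
        if filters_rank:
          num_spatial_dims = filters_rank - 2
        elif inputs_rank:
          num_spatial_dims = inputs_rank - 2
        else:
          raise ValueError("rank of input or filter must be known")
      return num_spatial_dims

    print(infer_spatial_dims(tf.ones((1, 1, 1, 1)), tf.ones((1, 1, 1, 1))))  # 2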

PiperOrigin-RevId: 316178077
Change-Id: I63efb501a84a95583dd98917e04547fd135cd6e1
Author: Thomas O'Malley
Date: 2020-06-12 14:29:11 -07:00
Committed by: TensorFlower Gardener
Parent: 722065ae57
Commit: aadef9318e
4 changed files with 108 additions and 74 deletions


@@ -1432,6 +1432,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
    self._run(fn, 10000)

  def benchmark_tf_nn_convolution_overhead(self):
    inputs = array_ops.ones((1, 1, 1, 1))
    filters = array_ops.ones((1, 1, 1, 1))

    def fn():
      nn_ops.convolution_v2(inputs, filters)

    self._run(fn, 10000)


if __name__ == "__main__":
  test.main()
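
A rough standalone equivalent of what this microbenchmark measures, using only the public API (the helper name and iteration count are illustrative; the real benchmark runs inside TF's benchmark harness):

    import timeit
    import tensorflow as tf

    inputs = tf.ones((1, 1, 1, 1))
    filters = tf.ones((1, 1, 1, 1))

    def call_conv():
      # Tiny tensors keep the kernel cost negligible, so the timing is
      # dominated by Python/eager dispatch overhead.
      tf.nn.convolution(inputs, filters)

    n = 10000
    total = timeit.timeit(call_conv, number=n)
    print("approx. per-call overhead: %.1f us" % (total / n * 1e6))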


@@ -224,8 +224,8 @@ class Conv(Layer):
      tf_padding = self.padding.upper()
    else:
      tf_padding = self.padding
    tf_dilations = self.dilation_rate
    tf_strides = self.strides
    tf_dilations = list(self.dilation_rate)
    tf_strides = list(self.strides)

    tf_op_name = self.__class__.__name__
    if tf_op_name == 'Conv1D':
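
The two new list() calls above presumably exist so that the strides and dilations bound into the layer's convolution op at build time are already plain lists, letting the rewritten _get_sequence shown below take its cheapest branch on every call. A hypothetical sketch of that build-once, call-many pattern (not the layer's actual code):

    import functools
    import tensorflow as tf
    from tensorflow.python.ops import nn_ops

    # Convert tuples to lists once, then reuse them for every call.
    strides, dilations = (1, 1), (1, 1)
    conv_op = functools.partial(
        nn_ops.convolution_v2,
        strides=list(strides),
        dilations=list(dilations),
        padding="SAME")

    out = conv_op(tf.ones((1, 4, 4, 3)), tf.ones((3, 3, 3, 8)))
    print(out.shape)  # (1, 4, 4, 8)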


@@ -40,6 +40,7 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_lib
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
@@ -58,25 +59,46 @@ local_response_normalization = gen_nn_ops.lrn
# pylint: disable=protected-access

# Acceptable channels last formats (robust to H, W, D order).
_CHANNELS_LAST_FORMATS = frozenset({
    "NWC", "NHC", "NHWC", "NWHC", "NDHWC", "NDWHC", "NHDWC", "NHWDC", "NWDHC",
    "NWHDC"
})


def _get_sequence(value, n, channel_index, name):
  """Formats a value input for gen_nn_ops."""
  # Performance is fast-pathed for common cases:
  # `None`, `list`, `tuple` and `int`.
  if value is None:
    value = [1]
    return [1] * (n + 2)

  # Always convert `value` to a `list`.
  if isinstance(value, list):
    pass
  elif isinstance(value, tuple):
    value = list(value)
  elif isinstance(value, int):
    value = [value]
  elif not isinstance(value, collections_abc.Sized):
    value = [value]
  current_n = len(value)
  if current_n == n + 2:
    return value
  elif current_n == 1:
    value = list((value[0],) * n)
  elif current_n == n:
    value = list(value)
  else:
    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
        name, n, n + 2, current_n))
    value = list(value)  # Try casting to a list.

  len_value = len(value)

  # Fully specified, including batch and channel dims.
  if len_value == n + 2:
    return value

  # Apply value to spatial dims only.
  if len_value == 1:
    value = value * n  # Broadcast to spatial dimensions.
  elif len_value != n:
    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
        name, n, n + 2, len_value))

  # Add batch and channel dims (always 1).
  if channel_index == 1:
    return [1, 1] + value
  else:
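
The rewritten helper keeps the same contract as before; only the type dispatch is cheaper. A few illustrative calls against the private helper (channel_index == 1 corresponds to channels-first layouts; the expected outputs follow from the code above):

    from tensorflow.python.ops import nn_ops

    print(nn_ops._get_sequence(None, 2, 1, "strides"))          # [1, 1, 1, 1]
    print(nn_ops._get_sequence(2, 2, 1, "strides"))             # [1, 1, 2, 2]
    print(nn_ops._get_sequence((2, 3), 2, 1, "strides"))        # [1, 1, 2, 3]
    print(nn_ops._get_sequence([1, 1, 2, 2], 2, 1, "strides"))  # returned as-is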
@@ -1042,71 +1064,80 @@ def convolution_internal(
`num_spatial_dims` is provided and incompatible with the value
estimated from `filters.shape`.
"""
n = None
if getattr(filters, 'shape', None) is None:
with ops.name_scope(name, 'convolution_internal', [filters, input]):
if (not isinstance(filters, variables_lib.Variable) and
not tensor_util.is_tensor(filters)):
with ops.name_scope("convolution_internal", None, [filters, input]):
filters = ops.convert_to_tensor(filters, name='filters')
if (isinstance(filters.shape, tensor_shape.TensorShape)
and filters.shape.rank is not None):
n = len(filters.shape) - 2
elif (not isinstance(filters.shape, tensor_shape.TensorShape)
and filters.shape is not None):
n = len(filters.shape) - 2
if (not isinstance(input, ops.Tensor) and not tensor_util.is_tensor(input)):
with ops.name_scope("convolution_internal", None, [filters, input]):
input = ops.convert_to_tensor(input, name="input")
if (isinstance(input.shape, tensor_shape.TensorShape)
and input.shape.rank is not None):
if n is None:
n = (num_spatial_dims if num_spatial_dims is not None
else len(input.shape) - 2)
num_batch_dims = len(input.shape) - n - 1
elif (not isinstance(input.shape, tensor_shape.TensorShape)
and input.shape is not None):
if n is None:
n = (num_spatial_dims if num_spatial_dims is not None
else len(input.shape) - 2)
num_batch_dims = len(input.shape) - n - 1
else:
num_batch_dims = 1 # Default behavior if it cannot be estimated.
if n is None:
raise ValueError("rank of input or filter must be known")
if num_spatial_dims is not None and n != num_spatial_dims:
filters_rank = filters.shape.rank
inputs_rank = input.shape.rank
if num_spatial_dims is None:
if filters_rank:
num_spatial_dims = filters_rank - 2
elif inputs_rank:
num_spatial_dims = inputs_rank - 2
else:
raise ValueError("rank of input or filter must be known")
elif filters_rank and filters_rank - 2 != num_spatial_dims:
raise ValueError(
"inconsistent estimate of spatial dims ({}) vs. actual passed "
"num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, "
"but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
"but filters shape is: {}".format(filters_rank, num_spatial_dims,
filters.shape))
if not 1 <= n <= 3:
if inputs_rank:
num_batch_dims = inputs_rank - num_spatial_dims - 1 # Channel dimension.
else:
num_batch_dims = 1 # By default, assume single batch dimension.
if num_spatial_dims not in {1, 2, 3}:
raise ValueError(
"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
"of 1, 2 or 3 but saw {}. num_batch_dims: {}."
.format(n, num_batch_dims))
"of 1, 2 or 3 but saw {}. num_batch_dims: {}.".format(
num_spatial_dims, num_batch_dims))
if data_format is None:
channel_index = num_batch_dims + n
if data_format is None or data_format in _CHANNELS_LAST_FORMATS:
channel_index = num_batch_dims + num_spatial_dims
else:
channel_index = (
num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
channel_index = num_batch_dims
strides = _get_sequence(strides, n, channel_index, "strides")
dilations = _get_sequence(dilations, n, channel_index, "dilations")
scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
if not call_from_convolution and device_context.enclosing_tpu_context(
) is not None:
scope = scopes[n]
if dilations is None:
dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
"dilations")
is_dilated_conv = False
else:
scope = "convolution"
dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
"dilations")
is_dilated_conv = any(i != 1 for i in dilations)
with ops.name_scope(name, scope, [input, filters]) as name:
conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: _conv3d_expanded_batch}
strides = _get_sequence(strides, num_spatial_dims, channel_index, "strides")
has_tpu_context = device_context.enclosing_tpu_context() is not None
if name:
default_name = None
elif not has_tpu_context or call_from_convolution:
default_name = "convolution"
elif num_spatial_dims == 2: # Most common case.
default_name = "Conv2D"
elif num_spatial_dims == 3:
default_name = "Conv3D"
else:
default_name = "conv1d"
with ops.name_scope(name, default_name, [input, filters]) as name:
# Fast path for TPU or if no dilation, as gradient only supported on TPU
# for dilations.
if not is_dilated_conv or has_tpu_context:
if num_spatial_dims == 2: # Most common case.
op = _conv2d_expanded_batch
elif num_spatial_dims == 3:
op = _conv3d_expanded_batch
else:
op = conv1d
if device_context.enclosing_tpu_context() is not None or all(
i == 1 for i in dilations):
# fast path for TPU or if no dilation as gradient only supported on GPU
# for dilations
op = conv_ops[n]
return op(
input,
filters,
@@ -1131,7 +1162,7 @@ def convolution_internal(
dilation_rate=dilations,
name=name,
data_format=data_format,
num_spatial_dims=n)
num_spatial_dims=num_spatial_dims)
return op(input, filters)
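
The next hunk replaces a defensive getattr-based shape probe in _conv2d_expanded_batch with a single cached input.shape.rank read: rank 4 or lower (or unknown) dispatches straight to the Conv2D kernel, anything larger goes through squeeze_batch_dims. Both paths are reachable from the public API; an illustrative check (shapes here are arbitrary examples):

    import tensorflow as tf

    filters = tf.ones((3, 3, 3, 16))
    x4 = tf.ones((2, 8, 8, 3))      # plain NHWC batch -> direct Conv2D path
    x5 = tf.ones((4, 2, 8, 8, 3))   # extra batch dim -> squeeze_batch_dims path

    print(tf.nn.convolution(x4, filters).shape)  # (2, 6, 6, 16)
    print(tf.nn.convolution(x5, filters).shape)  # (4, 2, 6, 6, 16)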
@@ -2547,11 +2578,8 @@ def _conv2d_expanded_batch(
    name):
  """Helper function for `convolution_internal`; handles expanded batches."""
  # Try really hard to avoid modifying the legacy name scopes - return early.
  shape = getattr(input, "shape", None)
  if shape is not None:
    ndims = getattr(shape, "ndims", -1)
    if ndims == -1: ndims = len(shape)
    if ndims in (4, 3, 2, 1, 0, None):
  input_rank = input.shape.rank
  if input_rank is None or input_rank < 5:
    # We avoid calling squeeze_batch_dims to reduce extra python function
    # call slowdown in eager mode. This branch doesn't require reshapes.
    return gen_nn_ops.conv2d(


@@ -1502,11 +1502,8 @@ class MaxPoolTest(test_lib.TestCase):
class ConvolutionTest(test_lib.TestCase):

  def testUnknownSize(self):
    # explicitly use float32 for ROCm, as MIOpen does not yet support float64
    # np.ones defaults to using float64 when dtype is not explicitly specified
    dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64
    x = tensor_spec.TensorSpec(None, dtypes.float32, name="x")
    k = np.ones([3, 6, 6, 5], dtype=dtype)
    k = np.ones([3, 6, 6, 5], dtype=np.float32)

    @def_function.function
    def F(value):