Reduce Layer.__call__ overhead by 5-10%

Only retrieve mask arg if Layer.call or Layer.compute_mask need it.
Skips checking for implicit masks entirely otherwise.

PiperOrigin-RevId: 313444769
Change-Id: Ife930d4c299dce6463836e0e238d236b7582b2ee
This commit is contained in:
Thomas O'Malley 2020-05-27 12:47:20 -07:00 committed by TensorFlower Gardener
parent 07ff32309a
commit 6f7765b1bd

View File

@ -850,11 +850,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
# explicitly take priority.
mask_arg_passed_by_framework = False
input_masks = self._collect_input_masks(inputs, input_list, args, kwargs)
if (self._expects_mask_arg and input_masks is not None and
not self._call_arg_was_passed('mask', args, kwargs)):
mask_arg_passed_by_framework = True
input_masks, mask_is_implicit = self._get_input_masks(
inputs, input_list, args, kwargs)
if self._expects_mask_arg and mask_is_implicit:
kwargs['mask'] = input_masks
mask_arg_passed_by_framework = True
# If `training` argument is None or not explicitly passed,
# propagate `training` value from this layer's calling layer.
@ -2312,20 +2312,26 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Do not track masks for `TensorFlowOpLayer` construction.
output._keras_mask._keras_history_checked = True
def _collect_input_masks(self, inputs, input_list, args, kwargs):
"""Checks if `mask` argument was passed, else gathers mask from inputs."""
if self._call_arg_was_passed('mask', args, kwargs):
return self._get_call_arg_value('mask', args, kwargs)
if not self._should_compute_mask:
return None
input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
if all(mask is None for mask in input_masks):
return None
# Only do expensive `nest` operation when masking is actually being used.
return nest.pack_sequence_as(inputs, input_masks)
def _get_input_masks(self, inputs, input_list, args, kwargs):
if (not self._expects_mask_arg and not self.supports_masking and
not self._compute_mask_overridden):
# Input masks only need to be retrieved if they are needed for `call`
# or `compute_mask`.
input_masks = None
implicit_mask = False
elif self._call_arg_was_passed('mask', args, kwargs):
input_masks = self._get_call_arg_value('mask', args, kwargs)
implicit_mask = False
else:
input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
if all(mask is None for mask in input_masks):
input_masks = None
implicit_mask = False
else:
# Only do expensive `nest` op when masking is actually being used.
input_masks = nest.pack_sequence_as(inputs, input_masks)
implicit_mask = True
return input_masks, implicit_mask
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
# Performance optimization: do no work in most common case.
@ -2751,12 +2757,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def _call_accepts_kwargs(self):
return self._call_full_argspec.varkw is not None
@property
@tracking.cached_per_instance
def _should_compute_mask(self):
return ('mask' in self._call_fn_args or
getattr(self, 'compute_mask', None) is not None)
@property
def _eager_losses(self):
# A list of loss values containing activity regularizers and losses