Reduce Layer.__call__ overhead by 5-10%
Only retrieve mask arg if Layer.call or Layer.compute_mask need it. Skips checking for implicit masks entirely otherwise. PiperOrigin-RevId: 313444769 Change-Id: Ife930d4c299dce6463836e0e238d236b7582b2ee
This commit is contained in:
parent
07ff32309a
commit
6f7765b1bd
@ -850,11 +850,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
|
||||
# explicitly take priority.
|
||||
mask_arg_passed_by_framework = False
|
||||
input_masks = self._collect_input_masks(inputs, input_list, args, kwargs)
|
||||
if (self._expects_mask_arg and input_masks is not None and
|
||||
not self._call_arg_was_passed('mask', args, kwargs)):
|
||||
mask_arg_passed_by_framework = True
|
||||
input_masks, mask_is_implicit = self._get_input_masks(
|
||||
inputs, input_list, args, kwargs)
|
||||
if self._expects_mask_arg and mask_is_implicit:
|
||||
kwargs['mask'] = input_masks
|
||||
mask_arg_passed_by_framework = True
|
||||
|
||||
# If `training` argument is None or not explicitly passed,
|
||||
# propagate `training` value from this layer's calling layer.
|
||||
@ -2312,20 +2312,26 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# Do not track masks for `TensorFlowOpLayer` construction.
|
||||
output._keras_mask._keras_history_checked = True
|
||||
|
||||
def _collect_input_masks(self, inputs, input_list, args, kwargs):
|
||||
"""Checks if `mask` argument was passed, else gathers mask from inputs."""
|
||||
if self._call_arg_was_passed('mask', args, kwargs):
|
||||
return self._get_call_arg_value('mask', args, kwargs)
|
||||
|
||||
if not self._should_compute_mask:
|
||||
return None
|
||||
|
||||
input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
|
||||
if all(mask is None for mask in input_masks):
|
||||
return None
|
||||
|
||||
# Only do expensive `nest` operation when masking is actually being used.
|
||||
return nest.pack_sequence_as(inputs, input_masks)
|
||||
def _get_input_masks(self, inputs, input_list, args, kwargs):
|
||||
if (not self._expects_mask_arg and not self.supports_masking and
|
||||
not self._compute_mask_overridden):
|
||||
# Input masks only need to be retrieved if they are needed for `call`
|
||||
# or `compute_mask`.
|
||||
input_masks = None
|
||||
implicit_mask = False
|
||||
elif self._call_arg_was_passed('mask', args, kwargs):
|
||||
input_masks = self._get_call_arg_value('mask', args, kwargs)
|
||||
implicit_mask = False
|
||||
else:
|
||||
input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
|
||||
if all(mask is None for mask in input_masks):
|
||||
input_masks = None
|
||||
implicit_mask = False
|
||||
else:
|
||||
# Only do expensive `nest` op when masking is actually being used.
|
||||
input_masks = nest.pack_sequence_as(inputs, input_masks)
|
||||
implicit_mask = True
|
||||
return input_masks, implicit_mask
|
||||
|
||||
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||
# Performance optimization: do no work in most common case.
|
||||
@ -2751,12 +2757,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
def _call_accepts_kwargs(self):
|
||||
return self._call_full_argspec.varkw is not None
|
||||
|
||||
@property
|
||||
@tracking.cached_per_instance
|
||||
def _should_compute_mask(self):
|
||||
return ('mask' in self._call_fn_args or
|
||||
getattr(self, 'compute_mask', None) is not None)
|
||||
|
||||
@property
|
||||
def _eager_losses(self):
|
||||
# A list of loss values containing activity regularizers and losses
|
||||
|
Loading…
Reference in New Issue
Block a user