From 6f7765b1bdaa3ed34958585311e69cfc53137405 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Wed, 27 May 2020 12:47:20 -0700 Subject: [PATCH] Reduce Layer.__call__ overhead by 5-10% Only retrieve mask arg if Layer.call or Layer.compute_mask need it. Skips checking for implicit masks entirely otherwise. PiperOrigin-RevId: 313444769 Change-Id: Ife930d4c299dce6463836e0e238d236b7582b2ee --- tensorflow/python/keras/engine/base_layer.py | 48 ++++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index b986f9a405e..4a43b0526f6 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -850,11 +850,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed # explicitly take priority. mask_arg_passed_by_framework = False - input_masks = self._collect_input_masks(inputs, input_list, args, kwargs) - if (self._expects_mask_arg and input_masks is not None and - not self._call_arg_was_passed('mask', args, kwargs)): - mask_arg_passed_by_framework = True + input_masks, mask_is_implicit = self._get_input_masks( + inputs, input_list, args, kwargs) + if self._expects_mask_arg and mask_is_implicit: kwargs['mask'] = input_masks + mask_arg_passed_by_framework = True # If `training` argument is None or not explicitly passed, # propagate `training` value from this layer's calling layer. @@ -2312,20 +2312,26 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Do not track masks for `TensorFlowOpLayer` construction. output._keras_mask._keras_history_checked = True - def _collect_input_masks(self, inputs, input_list, args, kwargs): - """Checks if `mask` argument was passed, else gathers mask from inputs.""" - if self._call_arg_was_passed('mask', args, kwargs): - return self._get_call_arg_value('mask', args, kwargs) - - if not self._should_compute_mask: - return None - - input_masks = [getattr(t, '_keras_mask', None) for t in input_list] - if all(mask is None for mask in input_masks): - return None - - # Only do expensive `nest` operation when masking is actually being used. - return nest.pack_sequence_as(inputs, input_masks) + def _get_input_masks(self, inputs, input_list, args, kwargs): + if (not self._expects_mask_arg and not self.supports_masking and + not self._compute_mask_overridden): + # Input masks only need to be retrieved if they are needed for `call` + # or `compute_mask`. + input_masks = None + implicit_mask = False + elif self._call_arg_was_passed('mask', args, kwargs): + input_masks = self._get_call_arg_value('mask', args, kwargs) + implicit_mask = False + else: + input_masks = [getattr(t, '_keras_mask', None) for t in input_list] + if all(mask is None for mask in input_masks): + input_masks = None + implicit_mask = False + else: + # Only do expensive `nest` op when masking is actually being used. + input_masks = nest.pack_sequence_as(inputs, input_masks) + implicit_mask = True + return input_masks, implicit_mask def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): # Performance optimization: do no work in most common case. @@ -2751,12 +2757,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): def _call_accepts_kwargs(self): return self._call_full_argspec.varkw is not None - @property - @tracking.cached_per_instance - def _should_compute_mask(self): - return ('mask' in self._call_fn_args or - getattr(self, 'compute_mask', None) is not None) - @property def _eager_losses(self): # A list of loss values containing activity regularizers and losses