From 27d373215c554bdbccc654f14b1f05738ab381d1 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Thu, 21 May 2020 17:00:17 -0700 Subject: [PATCH] Reduce Layer.__call__ overhead by ~5%. Layer._call_arg_was_passed now has a shortcut for the common case. PiperOrigin-RevId: 312767781 Change-Id: I97c926cf266e814f2d75c2beac63023faa715b7d --- tensorflow/python/keras/engine/base_layer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 0421772a75a..53d8cc5ab34 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2308,15 +2308,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector): return input_masks def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): + # Performance optimization: do no work in most common case. + if not args and not kwargs: + return False + if arg_name in kwargs: return True call_fn_args = self._call_fn_args if not inputs_in_args: # Ignore `inputs` arg. call_fn_args = call_fn_args[1:] - if arg_name in dict(zip(call_fn_args, args)): - return True - return False + return arg_name in dict(zip(call_fn_args, args)) def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): if arg_name in kwargs: