Reduce Layer.__call__ overhead by ~5%.

Layer._call_arg_was_passed now has a shortcut for the common case.

PiperOrigin-RevId: 312767781
Change-Id: I97c926cf266e814f2d75c2beac63023faa715b7d
This commit is contained in:
Thomas O'Malley 2020-05-21 17:00:17 -07:00 committed by TensorFlower Gardener
parent 50dc3262ea
commit 27d373215c
1 changed files with 5 additions and 3 deletions

View File

@ -2308,15 +2308,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return input_masks
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
# Performance optimization: do no work in most common case.
if not args and not kwargs:
return False
if arg_name in kwargs:
return True
call_fn_args = self._call_fn_args
if not inputs_in_args:
# Ignore `inputs` arg.
call_fn_args = call_fn_args[1:]
if arg_name in dict(zip(call_fn_args, args)):
return True
return False
return arg_name in dict(zip(call_fn_args, args))
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
if arg_name in kwargs: