Reduce Layer.__call__ overhead by ~5%.
Layer._call_arg_was_passed now has a shortcut for the common case. PiperOrigin-RevId: 312767781 Change-Id: I97c926cf266e814f2d75c2beac63023faa715b7d
This commit is contained in:
parent
50dc3262ea
commit
27d373215c
|
@ -2308,15 +2308,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
return input_masks
|
||||
|
||||
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||
# Performance optimization: do no work in most common case.
|
||||
if not args and not kwargs:
|
||||
return False
|
||||
|
||||
if arg_name in kwargs:
|
||||
return True
|
||||
call_fn_args = self._call_fn_args
|
||||
if not inputs_in_args:
|
||||
# Ignore `inputs` arg.
|
||||
call_fn_args = call_fn_args[1:]
|
||||
if arg_name in dict(zip(call_fn_args, args)):
|
||||
return True
|
||||
return False
|
||||
return arg_name in dict(zip(call_fn_args, args))
|
||||
|
||||
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||
if arg_name in kwargs:
|
||||
|
|
Loading…
Reference in New Issue