Reduce Layer.__call__ overhead by ~5%.
Layer._call_arg_was_passed now has a shortcut for the common case. PiperOrigin-RevId: 312767781 Change-Id: I97c926cf266e814f2d75c2beac63023faa715b7d
This commit is contained in:
parent
50dc3262ea
commit
27d373215c
|
@ -2308,15 +2308,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||||
return input_masks
|
return input_masks
|
||||||
|
|
||||||
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||||
|
# Performance optimization: do no work in most common case.
|
||||||
|
if not args and not kwargs:
|
||||||
|
return False
|
||||||
|
|
||||||
if arg_name in kwargs:
|
if arg_name in kwargs:
|
||||||
return True
|
return True
|
||||||
call_fn_args = self._call_fn_args
|
call_fn_args = self._call_fn_args
|
||||||
if not inputs_in_args:
|
if not inputs_in_args:
|
||||||
# Ignore `inputs` arg.
|
# Ignore `inputs` arg.
|
||||||
call_fn_args = call_fn_args[1:]
|
call_fn_args = call_fn_args[1:]
|
||||||
if arg_name in dict(zip(call_fn_args, args)):
|
return arg_name in dict(zip(call_fn_args, args))
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
|
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||||
if arg_name in kwargs:
|
if arg_name in kwargs:
|
||||||
|
|
Loading…
Reference in New Issue