Fork the tf function related util to keras.
PiperOrigin-RevId: 339983364 Change-Id: I877f2394f13b899ace0bb2893e6cb5f073b03458
This commit is contained in:
parent
0509af7e72
commit
05483cd409
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -28,11 +29,12 @@ from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.mixed_precision import policy
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import function_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Avoid breaking users who directly import this symbol from this file.
|
||||
@ -541,7 +543,7 @@ class Layer(base_layer.Layer):
|
||||
try:
|
||||
call_has_scope_arg = self._call_has_scope_arg
|
||||
except AttributeError:
|
||||
self._call_fn_args = function_utils.fn_args(self.call)
|
||||
self._call_fn_args = fn_args(self.call)
|
||||
self._call_has_scope_arg = 'scope' in self._call_fn_args
|
||||
call_has_scope_arg = self._call_has_scope_arg
|
||||
if call_has_scope_arg:
|
||||
@ -595,3 +597,35 @@ def _add_elements_to_collection(elements, collection_list):
|
||||
for element in elements:
|
||||
if id(element) not in collection_set:
|
||||
collection.append(element)
|
||||
|
||||
|
||||
def fn_args(fn):
|
||||
"""Get argument names for function-like object.
|
||||
|
||||
Args:
|
||||
fn: Function, or function-like object (e.g., result of `functools.partial`).
|
||||
|
||||
Returns:
|
||||
`tuple` of string argument names.
|
||||
|
||||
Raises:
|
||||
ValueError: if partial function has positionally bound arguments
|
||||
"""
|
||||
if isinstance(fn, functools.partial):
|
||||
args = fn_args(fn.func)
|
||||
args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
|
||||
else:
|
||||
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
|
||||
fn = fn.__call__
|
||||
args = tf_inspect.getfullargspec(fn).args
|
||||
if is_bound_method(fn) and args:
|
||||
# If it's a bound method, it may or may not have a self/cls first
|
||||
# argument; for example, self could be captured in *args.
|
||||
# If it does have a positional argument, it is self/cls.
|
||||
args.pop(0)
|
||||
return tuple(args)
|
||||
|
||||
|
||||
def is_bound_method(fn):
|
||||
_, fn = tf_decorator.unwrap(fn)
|
||||
return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
|
||||
|
Loading…
Reference in New Issue
Block a user