Support functools.partial as callable object in tf_inspect.getargspec.
PiperOrigin-RevId: 197036874
This commit is contained in:
parent
9a815b422a
commit
7911247bef
@ -18,8 +18,11 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
import functools
|
||||||
import inspect as _inspect
|
import inspect as _inspect
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
ArgSpec = _inspect.ArgSpec
|
ArgSpec = _inspect.ArgSpec
|
||||||
@ -43,16 +46,95 @@ def getargspec(object): # pylint: disable=redefined-builtin
|
|||||||
"""TFDecorator-aware replacement for inspect.getargspec.
|
"""TFDecorator-aware replacement for inspect.getargspec.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object: A callable, possibly decorated.
|
object: A callable (function or partial function), possibly decorated.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The `ArgSpec` that describes the signature of the outermost decorator that
|
The `ArgSpec` that describes the signature of the outermost decorator that
|
||||||
changes the callable's signature. If the callable is not decorated,
|
changes the callable's signature. If the callable is not decorated,
|
||||||
`inspect.getargspec()` will be called directly on the callable.
|
`inspect.getargspec()` will be called directly on the callable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: When callable's function signature can not be expressed with
|
||||||
|
ArgSpec.
|
||||||
"""
|
"""
|
||||||
decorators, target = tf_decorator.unwrap(object)
|
|
||||||
return next((d.decorator_argspec for d in decorators
|
def get_argspec_with_decorator(obj):
|
||||||
if d.decorator_argspec is not None), _inspect.getargspec(target))
|
decorators, target = tf_decorator.unwrap(obj)
|
||||||
|
return next((d.decorator_argspec
|
||||||
|
for d in decorators
|
||||||
|
if d.decorator_argspec is not None),
|
||||||
|
_inspect.getargspec(target))
|
||||||
|
|
||||||
|
if not isinstance(object, functools.partial):
|
||||||
|
return get_argspec_with_decorator(object)
|
||||||
|
|
||||||
|
# When callable is a functools.partial object, we construct its ArgSpec with
|
||||||
|
# following strategy:
|
||||||
|
# - If callable partial contains default value for positional arguments (ie.
|
||||||
|
# object.args), then final ArgSpec doesn't contain those positional arguments.
|
||||||
|
# - If callable partial contains default value for keyword arguments (ie.
|
||||||
|
# object.keywords), then we merge them with wrapped target. Default values
|
||||||
|
# from callable partial takes precedence over those from wrapped target.
|
||||||
|
#
|
||||||
|
# However, there is a case where it is impossible to construct a valid
|
||||||
|
# ArgSpec. Python requires arguments that have no default values must be
|
||||||
|
# defined before those with default values. ArgSpec structure is only valid
|
||||||
|
# when this presumption holds true because default values are expressed as a
|
||||||
|
# tuple of values without keywords and they are always assumed to belong to
|
||||||
|
# last K arguments where K is number of default values present.
|
||||||
|
#
|
||||||
|
# Since functools.partial can give default value to any argument, this
|
||||||
|
# presumption may no longer hold in some cases. For example:
|
||||||
|
#
|
||||||
|
# def func(m, n):
|
||||||
|
# return 2 * m + n
|
||||||
|
# partialed = functools.partial(func, m=1)
|
||||||
|
#
|
||||||
|
# This example will result in m having a default value but n doesn't. This is
|
||||||
|
# usually not allowed in Python and can not be expressed in ArgSpec correctly.
|
||||||
|
#
|
||||||
|
# Thus, we must detect cases like this by finding first argument with default
|
||||||
|
# value and ensures all following arguments also have default values. When
|
||||||
|
# this is not true, a ValueError is raised.
|
||||||
|
|
||||||
|
n_prune_args = len(object.args)
|
||||||
|
partial_keywords = object.keywords or {}
|
||||||
|
|
||||||
|
args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
|
||||||
|
|
||||||
|
# Pruning first n_prune_args arguments.
|
||||||
|
args = args[n_prune_args:]
|
||||||
|
|
||||||
|
# Partial function may give default value to any argument, therefore length
|
||||||
|
# of default value list must be len(args) to allow each argument to
|
||||||
|
# potentially be given a default value.
|
||||||
|
all_defaults = [None] * len(args)
|
||||||
|
if defaults:
|
||||||
|
all_defaults[-len(defaults):] = defaults
|
||||||
|
|
||||||
|
# Fill in default values provided by partial function in all_defaults.
|
||||||
|
for kw, default in six.iteritems(partial_keywords):
|
||||||
|
idx = args.index(kw)
|
||||||
|
all_defaults[idx] = default
|
||||||
|
|
||||||
|
# Find first argument with default value set.
|
||||||
|
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
|
||||||
|
|
||||||
|
# If no default values are found, return ArgSpec with defaults=None.
|
||||||
|
if first_default is None:
|
||||||
|
return ArgSpec(args, varargs, keywords, None)
|
||||||
|
|
||||||
|
# Checks if all arguments have default value set after first one.
|
||||||
|
invalid_default_values = [
|
||||||
|
args[i] for i, j in enumerate(all_defaults) if not j and i > first_default
|
||||||
|
]
|
||||||
|
|
||||||
|
if invalid_default_values:
|
||||||
|
raise ValueError('Some arguments %s do not have default value, but they '
|
||||||
|
'are positioned after those with default values. This can '
|
||||||
|
'not be expressed with ArgSpec.' % invalid_default_values)
|
||||||
|
|
||||||
|
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
|
||||||
|
|
||||||
|
|
||||||
def getfullargspec(obj): # pylint: disable=redefined-builtin
|
def getfullargspec(obj): # pylint: disable=redefined-builtin
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -109,6 +110,141 @@ class TfInspectTest(test.TestCase):
|
|||||||
outer_argspec)
|
outer_argspec)
|
||||||
self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
|
self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialPositionalArgumentOnly(self):
|
||||||
|
"""Tests getargspec on partial function with only positional arguments."""
|
||||||
|
|
||||||
|
def func(m, n):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, 7)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['n'], varargs=None, keywords=None, defaults=None)
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialInvalidArgspec(self):
|
||||||
|
"""Tests getargspec on partial function that doesn't have valid argspec."""
|
||||||
|
|
||||||
|
def func(m, n, l, k=4):
|
||||||
|
return 2 * m + l + n * k
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, n=7)
|
||||||
|
|
||||||
|
exception_message = (r"Some arguments \['l'\] do not have default value, "
|
||||||
|
"but they are positioned after those with default "
|
||||||
|
"values. This can not be expressed with ArgSpec.")
|
||||||
|
with self.assertRaisesRegexp(ValueError, exception_message):
|
||||||
|
tf_inspect.getargspec(partial_func)
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialValidArgspec(self):
|
||||||
|
"""Tests getargspec on partial function with valid argspec."""
|
||||||
|
|
||||||
|
def func(m, n, l, k=4):
|
||||||
|
return 2 * m + l + n * k
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, n=7, l=2)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['m', 'n', 'l', 'k'],
|
||||||
|
varargs=None,
|
||||||
|
keywords=None,
|
||||||
|
defaults=(7, 2, 4))
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialNoArgumentsLeft(self):
|
||||||
|
"""Tests getargspec on partial function that prunes all arguments."""
|
||||||
|
|
||||||
|
def func(m, n):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, 7, 10)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=[], varargs=None, keywords=None, defaults=None)
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialKeywordArgument(self):
|
||||||
|
"""Tests getargspec on partial function that prunes some arguments."""
|
||||||
|
|
||||||
|
def func(m, n):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, n=7)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
|
||||||
|
"""Tests getargspec on partial function that prunes argument by keyword."""
|
||||||
|
|
||||||
|
def func(m=1, n=2):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, n=7)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialWithVarargs(self):
|
||||||
|
"""Tests getargspec on partial function with variable arguments."""
|
||||||
|
|
||||||
|
def func(m, *arg):
|
||||||
|
return m + len(arg)
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, 7, 8)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=[], varargs='arg', keywords=None, defaults=None)
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialWithVarkwargs(self):
|
||||||
|
"""Tests getargspec on partial function with variable keyword arguments."""
|
||||||
|
|
||||||
|
def func(m, n, **kwarg):
|
||||||
|
return m * n + len(kwarg)
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, 7)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['n'], varargs=None, keywords='kwarg', defaults=None)
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialWithDecorator(self):
|
||||||
|
"""Tests getargspec on decorated partial function."""
|
||||||
|
|
||||||
|
@test_decorator('decorator')
|
||||||
|
def func(m=1, n=2):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, n=7)
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
|
||||||
|
"""Tests getargspec on partial function with decorated argspec."""
|
||||||
|
|
||||||
|
argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['a', 'b', 'c'],
|
||||||
|
varargs=None,
|
||||||
|
keywords=None,
|
||||||
|
defaults=(1, 'hello'))
|
||||||
|
decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
|
||||||
|
argspec)
|
||||||
|
partial_argspec = tf_inspect.ArgSpec(
|
||||||
|
args=['a', 'b', 'c'],
|
||||||
|
varargs=None,
|
||||||
|
keywords=None,
|
||||||
|
defaults=(2, 1, 'hello'))
|
||||||
|
partial_with_decorator = functools.partial(decorator, a=2)
|
||||||
|
|
||||||
|
self.assertEqual(argspec, tf_inspect.getargspec(decorator))
|
||||||
|
self.assertEqual(partial_argspec,
|
||||||
|
tf_inspect.getargspec(partial_with_decorator))
|
||||||
|
|
||||||
def testGetDoc(self):
|
def testGetDoc(self):
|
||||||
self.assertEqual('Test Decorated Function With Defaults Docstring.',
|
self.assertEqual('Test Decorated Function With Defaults Docstring.',
|
||||||
tf_inspect.getdoc(test_decorated_function_with_defaults))
|
tf_inspect.getdoc(test_decorated_function_with_defaults))
|
||||||
|
Loading…
Reference in New Issue
Block a user