Support functools.partial as callable object in tf_inspect.getargspec.
PiperOrigin-RevId: 197036874
This commit is contained in:
parent
9a815b422a
commit
7911247bef
@ -18,8 +18,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
import inspect as _inspect
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
ArgSpec = _inspect.ArgSpec
|
||||
@ -43,16 +46,95 @@ def getargspec(object): # pylint: disable=redefined-builtin
|
||||
"""TFDecorator-aware replacement for inspect.getargspec.
|
||||
|
||||
Args:
|
||||
object: A callable, possibly decorated.
|
||||
object: A callable (function or partial function), possibly decorated.
|
||||
|
||||
Returns:
|
||||
The `ArgSpec` that describes the signature of the outermost decorator that
|
||||
changes the callable's signature. If the callable is not decorated,
|
||||
`inspect.getargspec()` will be called directly on the callable.
|
||||
|
||||
Raises:
|
||||
ValueError: When callable's function signature can not be expressed with
|
||||
ArgSpec.
|
||||
"""
|
||||
decorators, target = tf_decorator.unwrap(object)
|
||||
return next((d.decorator_argspec for d in decorators
|
||||
if d.decorator_argspec is not None), _inspect.getargspec(target))
|
||||
|
||||
def get_argspec_with_decorator(obj):
|
||||
decorators, target = tf_decorator.unwrap(obj)
|
||||
return next((d.decorator_argspec
|
||||
for d in decorators
|
||||
if d.decorator_argspec is not None),
|
||||
_inspect.getargspec(target))
|
||||
|
||||
if not isinstance(object, functools.partial):
|
||||
return get_argspec_with_decorator(object)
|
||||
|
||||
# When callable is a functools.partial object, we construct its ArgSpec with
|
||||
# following strategy:
|
||||
# - If callable partial contains default value for positional arguments (ie.
|
||||
# object.args), then final ArgSpec doesn't contain those positional arguments.
|
||||
# - If callable partial contains default value for keyword arguments (ie.
|
||||
# object.keywords), then we merge them with wrapped target. Default values
|
||||
# from callable partial takes precedence over those from wrapped target.
|
||||
#
|
||||
# However, there is a case where it is impossible to construct a valid
|
||||
# ArgSpec. Python requires arguments that have no default values must be
|
||||
# defined before those with default values. ArgSpec structure is only valid
|
||||
# when this presumption holds true because default values are expressed as a
|
||||
# tuple of values without keywords and they are always assumed to belong to
|
||||
# last K arguments where K is number of default values present.
|
||||
#
|
||||
# Since functools.partial can give default value to any argument, this
|
||||
# presumption may no longer hold in some cases. For example:
|
||||
#
|
||||
# def func(m, n):
|
||||
# return 2 * m + n
|
||||
# partialed = functools.partial(func, m=1)
|
||||
#
|
||||
# This example will result in m having a default value but n doesn't. This is
|
||||
# usually not allowed in Python and can not be expressed in ArgSpec correctly.
|
||||
#
|
||||
# Thus, we must detect cases like this by finding first argument with default
|
||||
# value and ensures all following arguments also have default values. When
|
||||
# this is not true, a ValueError is raised.
|
||||
|
||||
n_prune_args = len(object.args)
|
||||
partial_keywords = object.keywords or {}
|
||||
|
||||
args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
|
||||
|
||||
# Pruning first n_prune_args arguments.
|
||||
args = args[n_prune_args:]
|
||||
|
||||
# Partial function may give default value to any argument, therefore length
|
||||
# of default value list must be len(args) to allow each argument to
|
||||
# potentially be given a default value.
|
||||
all_defaults = [None] * len(args)
|
||||
if defaults:
|
||||
all_defaults[-len(defaults):] = defaults
|
||||
|
||||
# Fill in default values provided by partial function in all_defaults.
|
||||
for kw, default in six.iteritems(partial_keywords):
|
||||
idx = args.index(kw)
|
||||
all_defaults[idx] = default
|
||||
|
||||
# Find first argument with default value set.
|
||||
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
|
||||
|
||||
# If no default values are found, return ArgSpec with defaults=None.
|
||||
if first_default is None:
|
||||
return ArgSpec(args, varargs, keywords, None)
|
||||
|
||||
# Checks if all arguments have default value set after first one.
|
||||
invalid_default_values = [
|
||||
args[i] for i, j in enumerate(all_defaults) if not j and i > first_default
|
||||
]
|
||||
|
||||
if invalid_default_values:
|
||||
raise ValueError('Some arguments %s do not have default value, but they '
|
||||
'are positioned after those with default values. This can '
|
||||
'not be expressed with ArgSpec.' % invalid_default_values)
|
||||
|
||||
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
|
||||
|
||||
|
||||
def getfullargspec(obj): # pylint: disable=redefined-builtin
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
@ -109,6 +110,141 @@ class TfInspectTest(test.TestCase):
|
||||
outer_argspec)
|
||||
self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
|
||||
|
||||
def testGetArgSpecOnPartialPositionalArgumentOnly(self):
|
||||
"""Tests getargspec on partial function with only positional arguments."""
|
||||
|
||||
def func(m, n):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, 7)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['n'], varargs=None, keywords=None, defaults=None)
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialInvalidArgspec(self):
|
||||
"""Tests getargspec on partial function that doesn't have valid argspec."""
|
||||
|
||||
def func(m, n, l, k=4):
|
||||
return 2 * m + l + n * k
|
||||
|
||||
partial_func = functools.partial(func, n=7)
|
||||
|
||||
exception_message = (r"Some arguments \['l'\] do not have default value, "
|
||||
"but they are positioned after those with default "
|
||||
"values. This can not be expressed with ArgSpec.")
|
||||
with self.assertRaisesRegexp(ValueError, exception_message):
|
||||
tf_inspect.getargspec(partial_func)
|
||||
|
||||
def testGetArgSpecOnPartialValidArgspec(self):
|
||||
"""Tests getargspec on partial function with valid argspec."""
|
||||
|
||||
def func(m, n, l, k=4):
|
||||
return 2 * m + l + n * k
|
||||
|
||||
partial_func = functools.partial(func, n=7, l=2)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['m', 'n', 'l', 'k'],
|
||||
varargs=None,
|
||||
keywords=None,
|
||||
defaults=(7, 2, 4))
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialNoArgumentsLeft(self):
|
||||
"""Tests getargspec on partial function that prunes all arguments."""
|
||||
|
||||
def func(m, n):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, 7, 10)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=[], varargs=None, keywords=None, defaults=None)
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialKeywordArgument(self):
|
||||
"""Tests getargspec on partial function that prunes some arguments."""
|
||||
|
||||
def func(m, n):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, n=7)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
|
||||
"""Tests getargspec on partial function that prunes argument by keyword."""
|
||||
|
||||
def func(m=1, n=2):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, n=7)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialWithVarargs(self):
|
||||
"""Tests getargspec on partial function with variable arguments."""
|
||||
|
||||
def func(m, *arg):
|
||||
return m + len(arg)
|
||||
|
||||
partial_func = functools.partial(func, 7, 8)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=[], varargs='arg', keywords=None, defaults=None)
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialWithVarkwargs(self):
|
||||
"""Tests getargspec on partial function with variable keyword arguments."""
|
||||
|
||||
def func(m, n, **kwarg):
|
||||
return m * n + len(kwarg)
|
||||
|
||||
partial_func = functools.partial(func, 7)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['n'], varargs=None, keywords='kwarg', defaults=None)
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialWithDecorator(self):
|
||||
"""Tests getargspec on decorated partial function."""
|
||||
|
||||
@test_decorator('decorator')
|
||||
def func(m=1, n=2):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, n=7)
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
|
||||
"""Tests getargspec on partial function with decorated argspec."""
|
||||
|
||||
argspec = tf_inspect.ArgSpec(
|
||||
args=['a', 'b', 'c'],
|
||||
varargs=None,
|
||||
keywords=None,
|
||||
defaults=(1, 'hello'))
|
||||
decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
|
||||
argspec)
|
||||
partial_argspec = tf_inspect.ArgSpec(
|
||||
args=['a', 'b', 'c'],
|
||||
varargs=None,
|
||||
keywords=None,
|
||||
defaults=(2, 1, 'hello'))
|
||||
partial_with_decorator = functools.partial(decorator, a=2)
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(decorator))
|
||||
self.assertEqual(partial_argspec,
|
||||
tf_inspect.getargspec(partial_with_decorator))
|
||||
|
||||
def testGetDoc(self):
|
||||
self.assertEqual('Test Decorated Function With Defaults Docstring.',
|
||||
tf_inspect.getdoc(test_decorated_function_with_defaults))
|
||||
|
Loading…
Reference in New Issue
Block a user