Support functools.partial as callable object in tf_inspect.getargspec.

PiperOrigin-RevId: 197036874
This commit is contained in:
A. Unique TensorFlower 2018-05-17 12:55:02 -07:00 committed by TensorFlower Gardener
parent 9a815b422a
commit 7911247bef
2 changed files with 222 additions and 4 deletions

View File

@ -18,8 +18,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from collections import namedtuple from collections import namedtuple
import functools
import inspect as _inspect import inspect as _inspect
import six
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
ArgSpec = _inspect.ArgSpec ArgSpec = _inspect.ArgSpec
@ -43,16 +46,95 @@ def getargspec(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.getargspec. """TFDecorator-aware replacement for inspect.getargspec.
Args: Args:
object: A callable, possibly decorated. object: A callable (function or partial function), possibly decorated.
Returns: Returns:
The `ArgSpec` that describes the signature of the outermost decorator that The `ArgSpec` that describes the signature of the outermost decorator that
changes the callable's signature. If the callable is not decorated, changes the callable's signature. If the callable is not decorated,
`inspect.getargspec()` will be called directly on the callable. `inspect.getargspec()` will be called directly on the callable.
Raises:
ValueError: When callable's function signature can not be expressed with
ArgSpec.
""" """
decorators, target = tf_decorator.unwrap(object)
return next((d.decorator_argspec for d in decorators def get_argspec_with_decorator(obj):
if d.decorator_argspec is not None), _inspect.getargspec(target)) decorators, target = tf_decorator.unwrap(obj)
return next((d.decorator_argspec
for d in decorators
if d.decorator_argspec is not None),
_inspect.getargspec(target))
if not isinstance(object, functools.partial):
return get_argspec_with_decorator(object)
# When callable is a functools.partial object, we construct its ArgSpec with
# following strategy:
# - If callable partial contains default value for positional arguments (ie.
# object.args), then final ArgSpec doesn't contain those positional arguments.
# - If callable partial contains default value for keyword arguments (ie.
# object.keywords), then we merge them with wrapped target. Default values
# from callable partial takes precedence over those from wrapped target.
#
# However, there is a case where it is impossible to construct a valid
# ArgSpec. Python requires arguments that have no default values must be
# defined before those with default values. ArgSpec structure is only valid
# when this presumption holds true because default values are expressed as a
# tuple of values without keywords and they are always assumed to belong to
# last K arguments where K is number of default values present.
#
# Since functools.partial can give default value to any argument, this
# presumption may no longer hold in some cases. For example:
#
# def func(m, n):
# return 2 * m + n
# partialed = functools.partial(func, m=1)
#
# This example will result in m having a default value but n doesn't. This is
# usually not allowed in Python and can not be expressed in ArgSpec correctly.
#
# Thus, we must detect cases like this by finding first argument with default
# value and ensures all following arguments also have default values. When
# this is not true, a ValueError is raised.
n_prune_args = len(object.args)
partial_keywords = object.keywords or {}
args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
# Pruning first n_prune_args arguments.
args = args[n_prune_args:]
# Partial function may give default value to any argument, therefore length
# of default value list must be len(args) to allow each argument to
# potentially be given a default value.
all_defaults = [None] * len(args)
if defaults:
all_defaults[-len(defaults):] = defaults
# Fill in default values provided by partial function in all_defaults.
for kw, default in six.iteritems(partial_keywords):
idx = args.index(kw)
all_defaults[idx] = default
# Find first argument with default value set.
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
# If no default values are found, return ArgSpec with defaults=None.
if first_default is None:
return ArgSpec(args, varargs, keywords, None)
# Checks if all arguments have default value set after first one.
invalid_default_values = [
args[i] for i, j in enumerate(all_defaults) if not j and i > first_default
]
if invalid_default_values:
raise ValueError('Some arguments %s do not have default value, but they '
'are positioned after those with default values. This can '
'not be expressed with ArgSpec.' % invalid_default_values)
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
def getfullargspec(obj): # pylint: disable=redefined-builtin def getfullargspec(obj): # pylint: disable=redefined-builtin

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import inspect import inspect
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -109,6 +110,141 @@ class TfInspectTest(test.TestCase):
outer_argspec) outer_argspec)
self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator)) self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
def testGetArgSpecOnPartialPositionalArgumentOnly(self):
"""Tests getargspec on partial function with only positional arguments."""
def func(m, n):
return 2 * m + n
partial_func = functools.partial(func, 7)
argspec = tf_inspect.ArgSpec(
args=['n'], varargs=None, keywords=None, defaults=None)
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""
def func(m, n, l, k=4):
return 2 * m + l + n * k
partial_func = functools.partial(func, n=7)
exception_message = (r"Some arguments \['l'\] do not have default value, "
"but they are positioned after those with default "
"values. This can not be expressed with ArgSpec.")
with self.assertRaisesRegexp(ValueError, exception_message):
tf_inspect.getargspec(partial_func)
def testGetArgSpecOnPartialValidArgspec(self):
"""Tests getargspec on partial function with valid argspec."""
def func(m, n, l, k=4):
return 2 * m + l + n * k
partial_func = functools.partial(func, n=7, l=2)
argspec = tf_inspect.ArgSpec(
args=['m', 'n', 'l', 'k'],
varargs=None,
keywords=None,
defaults=(7, 2, 4))
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialNoArgumentsLeft(self):
"""Tests getargspec on partial function that prunes all arguments."""
def func(m, n):
return 2 * m + n
partial_func = functools.partial(func, 7, 10)
argspec = tf_inspect.ArgSpec(
args=[], varargs=None, keywords=None, defaults=None)
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialKeywordArgument(self):
"""Tests getargspec on partial function that prunes some arguments."""
def func(m, n):
return 2 * m + n
partial_func = functools.partial(func, n=7)
argspec = tf_inspect.ArgSpec(
args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
"""Tests getargspec on partial function that prunes argument by keyword."""
def func(m=1, n=2):
return 2 * m + n
partial_func = functools.partial(func, n=7)
argspec = tf_inspect.ArgSpec(
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialWithVarargs(self):
"""Tests getargspec on partial function with variable arguments."""
def func(m, *arg):
return m + len(arg)
partial_func = functools.partial(func, 7, 8)
argspec = tf_inspect.ArgSpec(
args=[], varargs='arg', keywords=None, defaults=None)
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialWithVarkwargs(self):
"""Tests getargspec on partial function with variable keyword arguments."""
def func(m, n, **kwarg):
return m * n + len(kwarg)
partial_func = functools.partial(func, 7)
argspec = tf_inspect.ArgSpec(
args=['n'], varargs=None, keywords='kwarg', defaults=None)
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialWithDecorator(self):
"""Tests getargspec on decorated partial function."""
@test_decorator('decorator')
def func(m=1, n=2):
return 2 * m + n
partial_func = functools.partial(func, n=7)
argspec = tf_inspect.ArgSpec(
args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
"""Tests getargspec on partial function with decorated argspec."""
argspec = tf_inspect.ArgSpec(
args=['a', 'b', 'c'],
varargs=None,
keywords=None,
defaults=(1, 'hello'))
decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
argspec)
partial_argspec = tf_inspect.ArgSpec(
args=['a', 'b', 'c'],
varargs=None,
keywords=None,
defaults=(2, 1, 'hello'))
partial_with_decorator = functools.partial(decorator, a=2)
self.assertEqual(argspec, tf_inspect.getargspec(decorator))
self.assertEqual(partial_argspec,
tf_inspect.getargspec(partial_with_decorator))
def testGetDoc(self): def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.', self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults)) tf_inspect.getdoc(test_decorated_function_with_defaults))