From 7911247beffe993a5c96a5f26b534e2e8eac7cbf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 May 2018 12:55:02 -0700 Subject: [PATCH] Support functools.partial as callable object in tf_inspect.getargspec. PiperOrigin-RevId: 197036874 --- tensorflow/python/util/tf_inspect.py | 90 +++++++++++++- tensorflow/python/util/tf_inspect_test.py | 136 ++++++++++++++++++++++ 2 files changed, 222 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index 663036de8a0..33b389c8c48 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -18,8 +18,11 @@ from __future__ import division from __future__ import print_function from collections import namedtuple +import functools import inspect as _inspect +import six + from tensorflow.python.util import tf_decorator ArgSpec = _inspect.ArgSpec @@ -43,16 +46,95 @@ def getargspec(object): # pylint: disable=redefined-builtin """TFDecorator-aware replacement for inspect.getargspec. Args: - object: A callable, possibly decorated. + object: A callable (function or partial function), possibly decorated. Returns: The `ArgSpec` that describes the signature of the outermost decorator that changes the callable's signature. If the callable is not decorated, `inspect.getargspec()` will be called directly on the callable. + + Raises: + ValueError: When callable's function signature can not be expressed with + ArgSpec. """ - decorators, target = tf_decorator.unwrap(object) - return next((d.decorator_argspec for d in decorators - if d.decorator_argspec is not None), _inspect.getargspec(target)) + + def get_argspec_with_decorator(obj): + decorators, target = tf_decorator.unwrap(obj) + return next((d.decorator_argspec + for d in decorators + if d.decorator_argspec is not None), + _inspect.getargspec(target)) + + if not isinstance(object, functools.partial): + return get_argspec_with_decorator(object) + + # When callable is a functools.partial object, we construct its ArgSpec with + # following strategy: + # - If callable partial contains default value for positional arguments (ie. + # object.args), then final ArgSpec doesn't contain those positional arguments. + # - If callable partial contains default value for keyword arguments (ie. + # object.keywords), then we merge them with wrapped target. Default values + # from callable partial takes precedence over those from wrapped target. + # + # However, there is a case where it is impossible to construct a valid + # ArgSpec. Python requires arguments that have no default values must be + # defined before those with default values. ArgSpec structure is only valid + # when this presumption holds true because default values are expressed as a + # tuple of values without keywords and they are always assumed to belong to + # last K arguments where K is number of default values present. + # + # Since functools.partial can give default value to any argument, this + # presumption may no longer hold in some cases. For example: + # + # def func(m, n): + # return 2 * m + n + # partialed = functools.partial(func, m=1) + # + # This example will result in m having a default value but n doesn't. This is + # usually not allowed in Python and can not be expressed in ArgSpec correctly. + # + # Thus, we must detect cases like this by finding first argument with default + # value and ensures all following arguments also have default values. When + # this is not true, a ValueError is raised. + + n_prune_args = len(object.args) + partial_keywords = object.keywords or {} + + args, varargs, keywords, defaults = get_argspec_with_decorator(object.func) + + # Pruning first n_prune_args arguments. + args = args[n_prune_args:] + + # Partial function may give default value to any argument, therefore length + # of default value list must be len(args) to allow each argument to + # potentially be given a default value. + all_defaults = [None] * len(args) + if defaults: + all_defaults[-len(defaults):] = defaults + + # Fill in default values provided by partial function in all_defaults. + for kw, default in six.iteritems(partial_keywords): + idx = args.index(kw) + all_defaults[idx] = default + + # Find first argument with default value set. + first_default = next((idx for idx, x in enumerate(all_defaults) if x), None) + + # If no default values are found, return ArgSpec with defaults=None. + if first_default is None: + return ArgSpec(args, varargs, keywords, None) + + # Checks if all arguments have default value set after first one. + invalid_default_values = [ + args[i] for i, j in enumerate(all_defaults) if not j and i > first_default + ] + + if invalid_default_values: + raise ValueError('Some arguments %s do not have default value, but they ' + 'are positioned after those with default values. This can ' + 'not be expressed with ArgSpec.' % invalid_default_values) + + return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:])) def getfullargspec(obj): # pylint: disable=redefined-builtin diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index 129408449eb..325131c4f47 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import inspect from tensorflow.python.platform import test @@ -109,6 +110,141 @@ class TfInspectTest(test.TestCase): outer_argspec) self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator)) + def testGetArgSpecOnPartialPositionalArgumentOnly(self): + """Tests getargspec on partial function with only positional arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, 7) + argspec = tf_inspect.ArgSpec( + args=['n'], varargs=None, keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialInvalidArgspec(self): + """Tests getargspec on partial function that doesn't have valid argspec.""" + + def func(m, n, l, k=4): + return 2 * m + l + n * k + + partial_func = functools.partial(func, n=7) + + exception_message = (r"Some arguments \['l'\] do not have default value, " + "but they are positioned after those with default " + "values. This can not be expressed with ArgSpec.") + with self.assertRaisesRegexp(ValueError, exception_message): + tf_inspect.getargspec(partial_func) + + def testGetArgSpecOnPartialValidArgspec(self): + """Tests getargspec on partial function with valid argspec.""" + + def func(m, n, l, k=4): + return 2 * m + l + n * k + + partial_func = functools.partial(func, n=7, l=2) + argspec = tf_inspect.ArgSpec( + args=['m', 'n', 'l', 'k'], + varargs=None, + keywords=None, + defaults=(7, 2, 4)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialNoArgumentsLeft(self): + """Tests getargspec on partial function that prunes all arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, 7, 10) + argspec = tf_inspect.ArgSpec( + args=[], varargs=None, keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialKeywordArgument(self): + """Tests getargspec on partial function that prunes some arguments.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(7,)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self): + """Tests getargspec on partial function that prunes argument by keyword.""" + + def func(m=1, n=2): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithVarargs(self): + """Tests getargspec on partial function with variable arguments.""" + + def func(m, *arg): + return m + len(arg) + + partial_func = functools.partial(func, 7, 8) + argspec = tf_inspect.ArgSpec( + args=[], varargs='arg', keywords=None, defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithVarkwargs(self): + """Tests getargspec on partial function with variable keyword arguments.""" + + def func(m, n, **kwarg): + return m * n + len(kwarg) + + partial_func = functools.partial(func, 7) + argspec = tf_inspect.ArgSpec( + args=['n'], varargs=None, keywords='kwarg', defaults=None) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithDecorator(self): + """Tests getargspec on decorated partial function.""" + + @test_decorator('decorator') + def func(m=1, n=2): + return 2 * m + n + + partial_func = functools.partial(func, n=7) + argspec = tf_inspect.ArgSpec( + args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7)) + + self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + + def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self): + """Tests getargspec on partial function with decorated argspec.""" + + argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(1, 'hello')) + decorator = tf_decorator.TFDecorator('', test_undecorated_function, '', + argspec) + partial_argspec = tf_inspect.ArgSpec( + args=['a', 'b', 'c'], + varargs=None, + keywords=None, + defaults=(2, 1, 'hello')) + partial_with_decorator = functools.partial(decorator, a=2) + + self.assertEqual(argspec, tf_inspect.getargspec(decorator)) + self.assertEqual(partial_argspec, + tf_inspect.getargspec(partial_with_decorator)) + def testGetDoc(self): self.assertEqual('Test Decorated Function With Defaults Docstring.', tf_inspect.getdoc(test_decorated_function_with_defaults))