From 2fe93816b0bf9f9171df011e0f572c64932f9db2 Mon Sep 17 00:00:00 2001 From: Dan Moldovan <mdan@google.com> Date: Thu, 6 Jun 2019 09:25:44 -0700 Subject: [PATCH] Fix incorrect assumption that all methods have a positional first argument. PiperOrigin-RevId: 251865772 --- tensorflow/python/util/function_utils.py | 9 ++++++--- tensorflow/python/util/function_utils_test.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py index d1cc67ce38e..4c1e6eddaa2 100644 --- a/tensorflow/python/util/function_utils.py +++ b/tensorflow/python/util/function_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect -def _is_bounded_method(fn): +def _is_bound_method(fn): _, fn = tf_decorator.unwrap(fn) return tf_inspect.ismethod(fn) and (fn.__self__ is not None) @@ -55,8 +55,11 @@ def fn_args(fn): if _is_callable_object(fn): fn = fn.__call__ args = tf_inspect.getfullargspec(fn).args - if _is_bounded_method(fn): - args.pop(0) # remove `self` or `cls` + if _is_bound_method(fn) and args: + # If it's a bound method, it may or may not have a self/cls first + # argument; for example, self could be captured in *args. + # If it does have a positional argument, it is self/cls. + args.pop(0) return tuple(args) diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py index e5b0843e4b7..8fc740492c6 100644 --- a/tensorflow/python/util/function_utils_test.py +++ b/tensorflow/python/util/function_utils_test.py @@ -50,7 +50,7 @@ class FnArgsTest(test.TestCase): self.assertEqual(('a', 'b'), function_utils.fn_args(Foo())) - def test_bounded_method(self): + def test_bound_method(self): class Foo(object): @@ -59,6 +59,15 @@ class FnArgsTest(test.TestCase): self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar)) + def test_bound_method_no_self(self): + + class Foo(object): + + def bar(*args): # pylint:disable=no-method-argument + return args[1] + args[2] + + self.assertEqual((), function_utils.fn_args(Foo().bar)) + def test_partial_function(self): expected_test_arg = 123 @@ -159,7 +168,7 @@ class HasKwargsTest(test.TestCase): del x self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs())) - def test_bounded_method(self): + def test_bound_method(self): class FooHasKwargs(object):