diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index a53f03a03fa..906776d5dba 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -198,7 +198,9 @@ def _get_argspec_for_partial(obj): # Partial function may give default value to any argument, therefore length # of default value list must be len(args) to allow each argument to # potentially be given a default value. - all_defaults = [None] * len(args) + no_default = object() + all_defaults = [no_default] * len(args) + if defaults: all_defaults[-len(defaults):] = defaults @@ -208,7 +210,8 @@ def _get_argspec_for_partial(obj): all_defaults[idx] = default # Find first argument with default value set. - first_default = next((idx for idx, x in enumerate(all_defaults) if x), None) + first_default = next( + (idx for idx, x in enumerate(all_defaults) if x is not no_default), None) # If no default values are found, return ArgSpec with defaults=None. if first_default is None: @@ -217,7 +220,7 @@ def _get_argspec_for_partial(obj): # Checks if all arguments have default value set after first one. invalid_default_values = [ args[i] for i, j in enumerate(all_defaults) - if j is None and i > first_default + if j is no_default and i > first_default ] if invalid_default_values: diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index 910848e67f9..7c030d69216 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -122,6 +122,20 @@ class TfInspectTest(test.TestCase): self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + def testGetArgSpecOnPartialArgumentWithConvertibleToFalse(self): + """Tests getargspec on partial function with args that convert to False.""" + + def func(m, n): + return 2 * m + n + + partial_func = functools.partial(func, m=0) + + exception_message = (r"Some arguments \['n'\] do not have default value, " + "but they are positioned after those with default " + "values. This can not be expressed with ArgSpec.") + with self.assertRaisesRegexp(ValueError, exception_message): + tf_inspect.getargspec(partial_func) + def testGetArgSpecOnPartialInvalidArgspec(self): """Tests getargspec on partial function that doesn't have valid argspec."""