Fix a super confusing bug in tf_inspect.

If a default value was something that evaluates to False (like 0.0, or Tensor([0.0]), we would not get defaults.

Same as in cl/235632825: This does not solve all tf.function-and-partial bugs but makes debugging simpler.

PiperOrigin-RevId: 238621317
This commit is contained in:
Vojtech Bardiovsky 2019-03-15 04:41:26 -07:00 committed by TensorFlower Gardener
parent 3aa72b5efc
commit 6c4cccfc96
2 changed files with 20 additions and 3 deletions

View File

@ -198,7 +198,9 @@ def _get_argspec_for_partial(obj):
# Partial function may give default value to any argument, therefore length
# of default value list must be len(args) to allow each argument to
# potentially be given a default value.
all_defaults = [None] * len(args)
no_default = object()
all_defaults = [no_default] * len(args)
if defaults:
all_defaults[-len(defaults):] = defaults
@ -208,7 +210,8 @@ def _get_argspec_for_partial(obj):
all_defaults[idx] = default
# Find first argument with default value set.
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
first_default = next(
(idx for idx, x in enumerate(all_defaults) if x is not no_default), None)
# If no default values are found, return ArgSpec with defaults=None.
if first_default is None:
@ -217,7 +220,7 @@ def _get_argspec_for_partial(obj):
# Checks if all arguments have default value set after first one.
invalid_default_values = [
args[i] for i, j in enumerate(all_defaults)
if j is None and i > first_default
if j is no_default and i > first_default
]
if invalid_default_values:

View File

@ -122,6 +122,20 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
def testGetArgSpecOnPartialArgumentWithConvertibleToFalse(self):
"""Tests getargspec on partial function with args that convert to False."""
def func(m, n):
return 2 * m + n
partial_func = functools.partial(func, m=0)
exception_message = (r"Some arguments \['n'\] do not have default value, "
"but they are positioned after those with default "
"values. This can not be expressed with ArgSpec.")
with self.assertRaisesRegexp(ValueError, exception_message):
tf_inspect.getargspec(partial_func)
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""