Fix a super confusing bug in tf_inspect.
If a default value was something that evaluates to False (like 0.0, or Tensor([0.0]), we would not get defaults. Same as in cl/235632825: This does not solve all tf.function-and-partial bugs but makes debugging simpler. PiperOrigin-RevId: 238621317
This commit is contained in:
parent
3aa72b5efc
commit
6c4cccfc96
@ -198,7 +198,9 @@ def _get_argspec_for_partial(obj):
|
||||
# Partial function may give default value to any argument, therefore length
|
||||
# of default value list must be len(args) to allow each argument to
|
||||
# potentially be given a default value.
|
||||
all_defaults = [None] * len(args)
|
||||
no_default = object()
|
||||
all_defaults = [no_default] * len(args)
|
||||
|
||||
if defaults:
|
||||
all_defaults[-len(defaults):] = defaults
|
||||
|
||||
@ -208,7 +210,8 @@ def _get_argspec_for_partial(obj):
|
||||
all_defaults[idx] = default
|
||||
|
||||
# Find first argument with default value set.
|
||||
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
|
||||
first_default = next(
|
||||
(idx for idx, x in enumerate(all_defaults) if x is not no_default), None)
|
||||
|
||||
# If no default values are found, return ArgSpec with defaults=None.
|
||||
if first_default is None:
|
||||
@ -217,7 +220,7 @@ def _get_argspec_for_partial(obj):
|
||||
# Checks if all arguments have default value set after first one.
|
||||
invalid_default_values = [
|
||||
args[i] for i, j in enumerate(all_defaults)
|
||||
if j is None and i > first_default
|
||||
if j is no_default and i > first_default
|
||||
]
|
||||
|
||||
if invalid_default_values:
|
||||
|
@ -122,6 +122,20 @@ class TfInspectTest(test.TestCase):
|
||||
|
||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||
|
||||
def testGetArgSpecOnPartialArgumentWithConvertibleToFalse(self):
|
||||
"""Tests getargspec on partial function with args that convert to False."""
|
||||
|
||||
def func(m, n):
|
||||
return 2 * m + n
|
||||
|
||||
partial_func = functools.partial(func, m=0)
|
||||
|
||||
exception_message = (r"Some arguments \['n'\] do not have default value, "
|
||||
"but they are positioned after those with default "
|
||||
"values. This can not be expressed with ArgSpec.")
|
||||
with self.assertRaisesRegexp(ValueError, exception_message):
|
||||
tf_inspect.getargspec(partial_func)
|
||||
|
||||
def testGetArgSpecOnPartialInvalidArgspec(self):
|
||||
"""Tests getargspec on partial function that doesn't have valid argspec."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user