Fix a super confusing bug in tf_inspect.
If a default value was something that evaluates to False (like 0.0, or Tensor([0.0]), we would not get defaults. Same as in cl/235632825: This does not solve all tf.function-and-partial bugs but makes debugging simpler. PiperOrigin-RevId: 238621317
This commit is contained in:
parent
3aa72b5efc
commit
6c4cccfc96
@ -198,7 +198,9 @@ def _get_argspec_for_partial(obj):
|
|||||||
# Partial function may give default value to any argument, therefore length
|
# Partial function may give default value to any argument, therefore length
|
||||||
# of default value list must be len(args) to allow each argument to
|
# of default value list must be len(args) to allow each argument to
|
||||||
# potentially be given a default value.
|
# potentially be given a default value.
|
||||||
all_defaults = [None] * len(args)
|
no_default = object()
|
||||||
|
all_defaults = [no_default] * len(args)
|
||||||
|
|
||||||
if defaults:
|
if defaults:
|
||||||
all_defaults[-len(defaults):] = defaults
|
all_defaults[-len(defaults):] = defaults
|
||||||
|
|
||||||
@ -208,7 +210,8 @@ def _get_argspec_for_partial(obj):
|
|||||||
all_defaults[idx] = default
|
all_defaults[idx] = default
|
||||||
|
|
||||||
# Find first argument with default value set.
|
# Find first argument with default value set.
|
||||||
first_default = next((idx for idx, x in enumerate(all_defaults) if x), None)
|
first_default = next(
|
||||||
|
(idx for idx, x in enumerate(all_defaults) if x is not no_default), None)
|
||||||
|
|
||||||
# If no default values are found, return ArgSpec with defaults=None.
|
# If no default values are found, return ArgSpec with defaults=None.
|
||||||
if first_default is None:
|
if first_default is None:
|
||||||
@ -217,7 +220,7 @@ def _get_argspec_for_partial(obj):
|
|||||||
# Checks if all arguments have default value set after first one.
|
# Checks if all arguments have default value set after first one.
|
||||||
invalid_default_values = [
|
invalid_default_values = [
|
||||||
args[i] for i, j in enumerate(all_defaults)
|
args[i] for i, j in enumerate(all_defaults)
|
||||||
if j is None and i > first_default
|
if j is no_default and i > first_default
|
||||||
]
|
]
|
||||||
|
|
||||||
if invalid_default_values:
|
if invalid_default_values:
|
||||||
|
@ -122,6 +122,20 @@ class TfInspectTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
|
||||||
|
|
||||||
|
def testGetArgSpecOnPartialArgumentWithConvertibleToFalse(self):
|
||||||
|
"""Tests getargspec on partial function with args that convert to False."""
|
||||||
|
|
||||||
|
def func(m, n):
|
||||||
|
return 2 * m + n
|
||||||
|
|
||||||
|
partial_func = functools.partial(func, m=0)
|
||||||
|
|
||||||
|
exception_message = (r"Some arguments \['n'\] do not have default value, "
|
||||||
|
"but they are positioned after those with default "
|
||||||
|
"values. This can not be expressed with ArgSpec.")
|
||||||
|
with self.assertRaisesRegexp(ValueError, exception_message):
|
||||||
|
tf_inspect.getargspec(partial_func)
|
||||||
|
|
||||||
def testGetArgSpecOnPartialInvalidArgspec(self):
|
def testGetArgSpecOnPartialInvalidArgspec(self):
|
||||||
"""Tests getargspec on partial function that doesn't have valid argspec."""
|
"""Tests getargspec on partial function that doesn't have valid argspec."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user