Properly update the converted function's __kwdefaults__.
Fix getcallargs to avoid parameter clashes when the kwargs include "func". This same fix has been applied in the Python 3 library. See https://bugs.python.org/issue20108. PiperOrigin-RevId: 249154695
This commit is contained in:
parent
c93ebee01e
commit
e45cb1770d
@ -26,11 +26,14 @@ import os
|
||||
import pdb
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
|
||||
import textwrap
|
||||
from enum import Enum
|
||||
|
||||
# pylint:disable=g-bad-import-order
|
||||
import six
|
||||
# pylint:enable=g-bad-import-order
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.impl import conversion
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
@ -425,11 +428,16 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Defaults of %s : %s', converted_f,
|
||||
converted_f.__defaults__)
|
||||
if six.PY3:
|
||||
logging.log(2, 'KW defaults of %s : %s',
|
||||
converted_f, converted_f.__kwdefaults__)
|
||||
|
||||
if kwargs is not None:
|
||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args,
|
||||
**kwargs)
|
||||
else:
|
||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
|
||||
|
||||
formatted_callargs = '\n'.join(
|
||||
' {}: {}'.format(k, v) for k, v in callargs.items())
|
||||
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
|
||||
|
@ -283,6 +283,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
||||
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
||||
# Attach the default argument to the converted function.
|
||||
converted_entity.__defaults__ = entity.__defaults__
|
||||
if hasattr(entity, '__kwdefaults__'):
|
||||
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
||||
|
||||
return converted_entity
|
||||
|
||||
|
@ -257,12 +257,12 @@ def getfullargspec(obj):
|
||||
return _getfullargspec(target)
|
||||
|
||||
|
||||
def getcallargs(func, *positional, **named):
|
||||
def getcallargs(*func_and_positional, **named):
|
||||
"""TFDecorator-aware replacement for inspect.getcallargs.
|
||||
|
||||
Args:
|
||||
func: A callable, possibly decorated
|
||||
*positional: The positional arguments that would be passed to `func`.
|
||||
*func_and_positional: A callable, possibly decorated, followed by any
|
||||
positional arguments that would be passed to `func`.
|
||||
**named: The named argument dictionary that would be passed to `func`.
|
||||
|
||||
Returns:
|
||||
@ -273,6 +273,8 @@ def getcallargs(func, *positional, **named):
|
||||
it. If no attached decorators modify argspec, the final unwrapped target's
|
||||
argspec will be used.
|
||||
"""
|
||||
func = func_and_positional[0]
|
||||
positional = func_and_positional[1:]
|
||||
argspec = getfullargspec(func)
|
||||
call_args = named.copy()
|
||||
this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
|
||||
@ -285,6 +287,10 @@ def getcallargs(func, *positional, **named):
|
||||
for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
|
||||
if arg not in call_args:
|
||||
call_args[arg] = value
|
||||
if argspec.kwonlydefaults is not None:
|
||||
for k, v in argspec.kwonlydefaults.items():
|
||||
if k not in call_args:
|
||||
call_args[k] = v
|
||||
return call_args
|
||||
|
||||
|
||||
|
@ -594,6 +594,28 @@ class TfInspectGetCallArgsTest(test.TestCase):
|
||||
|
||||
self.assertEqual({}, tf_inspect.getcallargs(empty))
|
||||
|
||||
def testClashingParameterNames(self):
|
||||
|
||||
def func(positional, func=1, func_and_positional=2, kwargs=3):
|
||||
return positional, func, func_and_positional, kwargs
|
||||
|
||||
kwargs = {}
|
||||
self.assertEqual(
|
||||
tf_inspect.getcallargs(func, 0, **kwargs), {
|
||||
'positional': 0,
|
||||
'func': 1,
|
||||
'func_and_positional': 2,
|
||||
'kwargs': 3
|
||||
})
|
||||
kwargs = dict(func=4, func_and_positional=5, kwargs=6)
|
||||
self.assertEqual(
|
||||
tf_inspect.getcallargs(func, 0, **kwargs), {
|
||||
'positional': 0,
|
||||
'func': 4,
|
||||
'func_and_positional': 5,
|
||||
'kwargs': 6
|
||||
})
|
||||
|
||||
def testUnboundFuncWithOneParamPositional(self):
|
||||
|
||||
def func(a):
|
||||
|
Loading…
x
Reference in New Issue
Block a user