Properly update the converted function's __kwdefaults__.
Fix getcallargs to avoid parameter clashes when the kwargs include "func". This same fix has been applied in the Python 3 library. See https://bugs.python.org/issue20108. PiperOrigin-RevId: 249154695
This commit is contained in:
parent
c93ebee01e
commit
e45cb1770d
@ -26,11 +26,14 @@ import os
|
|||||||
import pdb
|
import pdb
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
|
||||||
import traceback
|
import traceback
|
||||||
|
import textwrap
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
# pylint:disable=g-bad-import-order
|
||||||
|
import six
|
||||||
|
# pylint:enable=g-bad-import-order
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.impl import conversion
|
from tensorflow.python.autograph.impl import conversion
|
||||||
from tensorflow.python.autograph.operators import py_builtins
|
from tensorflow.python.autograph.operators import py_builtins
|
||||||
@ -425,11 +428,16 @@ def converted_call(f, owner, options, args, kwargs):
|
|||||||
if logging.has_verbosity(2):
|
if logging.has_verbosity(2):
|
||||||
logging.log(2, 'Defaults of %s : %s', converted_f,
|
logging.log(2, 'Defaults of %s : %s', converted_f,
|
||||||
converted_f.__defaults__)
|
converted_f.__defaults__)
|
||||||
|
if six.PY3:
|
||||||
|
logging.log(2, 'KW defaults of %s : %s',
|
||||||
|
converted_f, converted_f.__kwdefaults__)
|
||||||
|
|
||||||
if kwargs is not None:
|
if kwargs is not None:
|
||||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args,
|
callargs = tf_inspect.getcallargs(converted_f, *effective_args,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
|
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
|
||||||
|
|
||||||
formatted_callargs = '\n'.join(
|
formatted_callargs = '\n'.join(
|
||||||
' {}: {}'.format(k, v) for k, v in callargs.items())
|
' {}: {}'.format(k, v) for k, v in callargs.items())
|
||||||
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
|
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
|
||||||
|
@ -283,6 +283,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
|||||||
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
||||||
# Attach the default argument to the converted function.
|
# Attach the default argument to the converted function.
|
||||||
converted_entity.__defaults__ = entity.__defaults__
|
converted_entity.__defaults__ = entity.__defaults__
|
||||||
|
if hasattr(entity, '__kwdefaults__'):
|
||||||
|
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
||||||
|
|
||||||
return converted_entity
|
return converted_entity
|
||||||
|
|
||||||
|
@ -257,12 +257,12 @@ def getfullargspec(obj):
|
|||||||
return _getfullargspec(target)
|
return _getfullargspec(target)
|
||||||
|
|
||||||
|
|
||||||
def getcallargs(func, *positional, **named):
|
def getcallargs(*func_and_positional, **named):
|
||||||
"""TFDecorator-aware replacement for inspect.getcallargs.
|
"""TFDecorator-aware replacement for inspect.getcallargs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: A callable, possibly decorated
|
*func_and_positional: A callable, possibly decorated, followed by any
|
||||||
*positional: The positional arguments that would be passed to `func`.
|
positional arguments that would be passed to `func`.
|
||||||
**named: The named argument dictionary that would be passed to `func`.
|
**named: The named argument dictionary that would be passed to `func`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -273,6 +273,8 @@ def getcallargs(func, *positional, **named):
|
|||||||
it. If no attached decorators modify argspec, the final unwrapped target's
|
it. If no attached decorators modify argspec, the final unwrapped target's
|
||||||
argspec will be used.
|
argspec will be used.
|
||||||
"""
|
"""
|
||||||
|
func = func_and_positional[0]
|
||||||
|
positional = func_and_positional[1:]
|
||||||
argspec = getfullargspec(func)
|
argspec = getfullargspec(func)
|
||||||
call_args = named.copy()
|
call_args = named.copy()
|
||||||
this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
|
this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
|
||||||
@ -285,6 +287,10 @@ def getcallargs(func, *positional, **named):
|
|||||||
for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
|
for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
|
||||||
if arg not in call_args:
|
if arg not in call_args:
|
||||||
call_args[arg] = value
|
call_args[arg] = value
|
||||||
|
if argspec.kwonlydefaults is not None:
|
||||||
|
for k, v in argspec.kwonlydefaults.items():
|
||||||
|
if k not in call_args:
|
||||||
|
call_args[k] = v
|
||||||
return call_args
|
return call_args
|
||||||
|
|
||||||
|
|
||||||
|
@ -594,6 +594,28 @@ class TfInspectGetCallArgsTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual({}, tf_inspect.getcallargs(empty))
|
self.assertEqual({}, tf_inspect.getcallargs(empty))
|
||||||
|
|
||||||
|
def testClashingParameterNames(self):
|
||||||
|
|
||||||
|
def func(positional, func=1, func_and_positional=2, kwargs=3):
|
||||||
|
return positional, func, func_and_positional, kwargs
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
self.assertEqual(
|
||||||
|
tf_inspect.getcallargs(func, 0, **kwargs), {
|
||||||
|
'positional': 0,
|
||||||
|
'func': 1,
|
||||||
|
'func_and_positional': 2,
|
||||||
|
'kwargs': 3
|
||||||
|
})
|
||||||
|
kwargs = dict(func=4, func_and_positional=5, kwargs=6)
|
||||||
|
self.assertEqual(
|
||||||
|
tf_inspect.getcallargs(func, 0, **kwargs), {
|
||||||
|
'positional': 0,
|
||||||
|
'func': 4,
|
||||||
|
'func_and_positional': 5,
|
||||||
|
'kwargs': 6
|
||||||
|
})
|
||||||
|
|
||||||
def testUnboundFuncWithOneParamPositional(self):
|
def testUnboundFuncWithOneParamPositional(self):
|
||||||
|
|
||||||
def func(a):
|
def func(a):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user