Properly update the converted function's __kwdefaults__.

Fix getcallargs to avoid parameter clashes when the kwargs include "func". This same fix has been applied in the Python 3 library. See https://bugs.python.org/issue20108.

PiperOrigin-RevId: 249154695
This commit is contained in:
Dan Moldovan 2019-05-20 16:55:43 -07:00 committed by TensorFlower Gardener
parent c93ebee01e
commit e45cb1770d
4 changed files with 43 additions and 5 deletions

View File

@ -26,11 +26,14 @@ import os
import pdb
import re
import sys
import textwrap
import traceback
import textwrap
from enum import Enum
# pylint:disable=g-bad-import-order
import six
# pylint:enable=g-bad-import-order
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.operators import py_builtins
@ -425,11 +428,16 @@ def converted_call(f, owner, options, args, kwargs):
if logging.has_verbosity(2):
logging.log(2, 'Defaults of %s : %s', converted_f,
converted_f.__defaults__)
if six.PY3:
logging.log(2, 'KW defaults of %s : %s',
converted_f, converted_f.__kwdefaults__)
if kwargs is not None:
callargs = tf_inspect.getcallargs(converted_f, *effective_args,
**kwargs)
else:
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

View File

@ -283,6 +283,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
# Attach the default argument to the converted function.
converted_entity.__defaults__ = entity.__defaults__
if hasattr(entity, '__kwdefaults__'):
converted_entity.__kwdefaults__ = entity.__kwdefaults__
return converted_entity

View File

@ -257,12 +257,12 @@ def getfullargspec(obj):
return _getfullargspec(target)
def getcallargs(func, *positional, **named):
def getcallargs(*func_and_positional, **named):
"""TFDecorator-aware replacement for inspect.getcallargs.
Args:
func: A callable, possibly decorated
*positional: The positional arguments that would be passed to `func`.
*func_and_positional: A callable, possibly decorated, followed by any
positional arguments that would be passed to `func`.
**named: The named argument dictionary that would be passed to `func`.
Returns:
@ -273,6 +273,8 @@ def getcallargs(func, *positional, **named):
it. If no attached decorators modify argspec, the final unwrapped target's
argspec will be used.
"""
func = func_and_positional[0]
positional = func_and_positional[1:]
argspec = getfullargspec(func)
call_args = named.copy()
this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
@ -285,6 +287,10 @@ def getcallargs(func, *positional, **named):
for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
if arg not in call_args:
call_args[arg] = value
if argspec.kwonlydefaults is not None:
for k, v in argspec.kwonlydefaults.items():
if k not in call_args:
call_args[k] = v
return call_args

View File

@ -594,6 +594,28 @@ class TfInspectGetCallArgsTest(test.TestCase):
self.assertEqual({}, tf_inspect.getcallargs(empty))
def testClashingParameterNames(self):
def func(positional, func=1, func_and_positional=2, kwargs=3):
return positional, func, func_and_positional, kwargs
kwargs = {}
self.assertEqual(
tf_inspect.getcallargs(func, 0, **kwargs), {
'positional': 0,
'func': 1,
'func_and_positional': 2,
'kwargs': 3
})
kwargs = dict(func=4, func_and_positional=5, kwargs=6)
self.assertEqual(
tf_inspect.getcallargs(func, 0, **kwargs), {
'positional': 0,
'func': 4,
'func_and_positional': 5,
'kwargs': 6
})
def testUnboundFuncWithOneParamPositional(self):
def func(a):