More consistently pass function arguments by omitting **kwargs when the target function did not define any and the caller did not include any keywords.
PiperOrigin-RevId: 242066459
This commit is contained in:
parent
f3c82c4646
commit
a1e3d4490d
@ -116,6 +116,9 @@ class CallTreeTransformer(converter.Base):
|
||||
k = self.visit(k)
|
||||
normal_keywords.append(k)
|
||||
if kwargs_arg is None:
|
||||
if not normal_keywords:
|
||||
kwargs = parser.parse_expression('None')
|
||||
else:
|
||||
kwargs = ast_util.keywords_to_dict(normal_keywords)
|
||||
else:
|
||||
kwargs = templates.replace_as_expression(
|
||||
|
@ -34,7 +34,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
self.assertEqual(
|
||||
result.test_fn(None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertListEqual(self.dynamic_calls, [((), {})])
|
||||
self.assertListEqual(self.dynamic_calls, [((), None)])
|
||||
|
||||
def test_function_with_expression_in_argument(self):
|
||||
|
||||
@ -46,8 +46,8 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
result.test_fn(None, None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), {}),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), {}),
|
||||
((), None),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None),
|
||||
])
|
||||
|
||||
def test_function_with_call_in_argument(self):
|
||||
@ -60,8 +60,8 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
result.test_fn(None, None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), {}),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), {}),
|
||||
((), None),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None),
|
||||
])
|
||||
|
||||
def test_function_with_kwarg(self):
|
||||
@ -100,7 +100,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.converted(test_fn, call_trees, {'f': f}) as result:
|
||||
self.assertEqual(result.test_fn(),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {})])
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
|
||||
|
||||
def test_function_with_kwargs_keywords(self):
|
||||
|
||||
@ -124,7 +124,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.converted(TestClass.test_method, call_trees, {}) as result:
|
||||
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
|
||||
result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {})])
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
||||
def test_object_method(self):
|
||||
|
||||
@ -137,7 +137,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.converted(tc.test_method, call_trees, {}) as result:
|
||||
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
|
||||
result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {})])
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -172,7 +172,10 @@ def _call_unconverted(f, args, kwargs):
|
||||
if inspect_utils.istfmethodtarget(f):
|
||||
return f.__self__.call(args, kwargs)
|
||||
|
||||
if kwargs is not None:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args)
|
||||
|
||||
|
||||
def _is_known_loaded_type(f, module_name, entity_name):
|
||||
@ -224,7 +227,10 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
f = getattr(owner, f)
|
||||
|
||||
if inspect_utils.isbuiltin(f):
|
||||
if kwargs:
|
||||
return py_builtins.overload_of(f)(*args, **kwargs)
|
||||
else:
|
||||
return py_builtins.overload_of(f)(*args)
|
||||
|
||||
if _is_known_loaded_type(f, 'weakref', 'ref'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: weakref', f)
|
||||
@ -275,6 +281,7 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
new_kwargs = {}
|
||||
if f.keywords is not None:
|
||||
new_kwargs.update(f.keywords)
|
||||
if kwargs is not None:
|
||||
new_kwargs.update(kwargs)
|
||||
kwargs = new_kwargs
|
||||
f = f.func
|
||||
@ -317,7 +324,11 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Defaults of %s : %s', converted_f,
|
||||
converted_f.__defaults__)
|
||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
|
||||
if kwargs is not None:
|
||||
callargs = tf_inspect.getcallargs(
|
||||
converted_f, *effective_args, **kwargs)
|
||||
else:
|
||||
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
|
||||
formatted_callargs = '\n'.join(
|
||||
' {}: {}'.format(k, v) for k, v in callargs.items())
|
||||
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
|
||||
@ -342,7 +353,10 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
|
||||
return _call_unconverted(f, args, kwargs)
|
||||
|
||||
if kwargs is not None:
|
||||
result = converted_f(*effective_args, **kwargs)
|
||||
else:
|
||||
result = converted_f(*effective_args)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -415,6 +415,16 @@ class ApiTest(test.TestCase):
|
||||
# The constant has static shape so the result is a primitive not a Tensor.
|
||||
self.assertEqual(x, 1)
|
||||
|
||||
def test_converted_call_no_kwargs_allowed(self):
|
||||
|
||||
def f(*args):
|
||||
# Note: np.broadcast rejects any **kwargs, even *{}
|
||||
return np.broadcast(args[:1])
|
||||
|
||||
opts = converter.ConversionOptions(internal_convert_user_code=False)
|
||||
|
||||
self.assertIsNotNone(api.converted_call(f, None, opts, (1, 2, 3, 4), None))
|
||||
|
||||
def test_converted_call_whitelisted_method(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
|
Loading…
Reference in New Issue
Block a user