More consistently pass function arguments by omitting **kwargs when the target function did not define any and the caller did not include any keywords.

PiperOrigin-RevId: 242066459
This commit is contained in:
Dan Moldovan 2019-04-04 21:23:49 -07:00 committed by TensorFlower Gardener
parent f3c82c4646
commit a1e3d4490d
4 changed files with 41 additions and 14 deletions

View File

@ -116,6 +116,9 @@ class CallTreeTransformer(converter.Base):
k = self.visit(k)
normal_keywords.append(k)
if kwargs_arg is None:
if not normal_keywords:
kwargs = parser.parse_expression('None')
else:
kwargs = ast_util.keywords_to_dict(normal_keywords)
else:
kwargs = templates.replace_as_expression(

View File

@ -34,7 +34,7 @@ class CallTreesTest(converter_testing.TestCase):
self.assertEqual(
result.test_fn(None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [((), {})])
self.assertListEqual(self.dynamic_calls, [((), None)])
def test_function_with_expression_in_argument(self):
@ -46,8 +46,8 @@ class CallTreesTest(converter_testing.TestCase):
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [
((), {}),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), {}),
((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None),
])
def test_function_with_call_in_argument(self):
@ -60,8 +60,8 @@ class CallTreesTest(converter_testing.TestCase):
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [
((), {}),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), {}),
((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None),
])
def test_function_with_kwarg(self):
@ -100,7 +100,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {'f': f}) as result:
self.assertEqual(result.test_fn(),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {})])
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
def test_function_with_kwargs_keywords(self):
@ -124,7 +124,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(TestClass.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), {})])
self.assertListEqual(self.dynamic_calls, [((1,), None)])
def test_object_method(self):
@ -137,7 +137,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(tc.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), {})])
self.assertListEqual(self.dynamic_calls, [((1,), None)])
if __name__ == '__main__':

View File

@ -172,7 +172,10 @@ def _call_unconverted(f, args, kwargs):
if inspect_utils.istfmethodtarget(f):
return f.__self__.call(args, kwargs)
if kwargs is not None:
return f(*args, **kwargs)
else:
return f(*args)
def _is_known_loaded_type(f, module_name, entity_name):
@ -224,7 +227,10 @@ def converted_call(f, owner, options, args, kwargs):
f = getattr(owner, f)
if inspect_utils.isbuiltin(f):
if kwargs:
return py_builtins.overload_of(f)(*args, **kwargs)
else:
return py_builtins.overload_of(f)(*args)
if _is_known_loaded_type(f, 'weakref', 'ref'):
logging.log(2, 'Permanently whitelisted: %s: weakref', f)
@ -275,6 +281,7 @@ def converted_call(f, owner, options, args, kwargs):
new_kwargs = {}
if f.keywords is not None:
new_kwargs.update(f.keywords)
if kwargs is not None:
new_kwargs.update(kwargs)
kwargs = new_kwargs
f = f.func
@ -317,7 +324,11 @@ def converted_call(f, owner, options, args, kwargs):
if logging.has_verbosity(2):
logging.log(2, 'Defaults of %s : %s', converted_f,
converted_f.__defaults__)
callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
if kwargs is not None:
callargs = tf_inspect.getcallargs(
converted_f, *effective_args, **kwargs)
else:
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
@ -342,7 +353,10 @@ def converted_call(f, owner, options, args, kwargs):
return _call_unconverted(f, args, kwargs)
if kwargs is not None:
result = converted_f(*effective_args, **kwargs)
else:
result = converted_f(*effective_args)
return result

View File

@ -415,6 +415,16 @@ class ApiTest(test.TestCase):
# The constant has static shape so the result is a primitive not a Tensor.
self.assertEqual(x, 1)
def test_converted_call_no_kwargs_allowed(self):
def f(*args):
# Note: np.broadcast rejects any **kwargs, even *{}
return np.broadcast(args[:1])
opts = converter.ConversionOptions(internal_convert_user_code=False)
self.assertIsNotNone(api.converted_call(f, None, opts, (1, 2, 3, 4), None))
def test_converted_call_whitelisted_method(self):
opts = converter.ConversionOptions()