From 8ba59743a42885635ab02e1cb5687b136e5a537d Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Sat, 12 Oct 2019 14:08:58 -0700 Subject: [PATCH] Handle cases of multiple *args and **kwargs, which are legal in Python3. PiperOrigin-RevId: 274371522 --- tensorflow/python/autograph/converters/BUILD | 12 +- .../python/autograph/converters/call_trees.py | 110 +++++++++++------- .../autograph/converters/call_trees_test.py | 36 +++++- tensorflow/python/autograph/impl/BUILD | 4 +- tensorflow/python/autograph/impl/api.py | 2 +- .../python/autograph/impl/api_py3_test.py | 2 +- 6 files changed, 117 insertions(+), 49 deletions(-) diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 0f6189ceaa7..d438dc6e3ba 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -42,6 +42,7 @@ py_library( "//tensorflow/python/autograph/lang", "//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct/static_analysis", + "//tensorflow/python/autograph/utils", "@gast_archive//:gast", ], ) @@ -82,13 +83,18 @@ py_test( py_test( name = "call_trees_test", srcs = ["call_trees_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + "no_pip", + "no_windows", + "nopip", + ], deps = [ ":converters", "//tensorflow/python:client_testlib", "//tensorflow/python/autograph/core:test_lib", - "//tensorflow/python/autograph/impl", ], ) diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 87db74d8dad..88c52b8a5c4 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -28,7 +28,6 @@ import gast from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.utils import ag_logging @@ -48,6 +47,46 @@ class _Function(object): set_trace_warned = False +class _ArgTemplateBuilder(object): + """Constructs a tuple representing the positional arguments in a call. + + Example (yes, it's legal Python 3): + + f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d) + """ + + def __init__(self): + self._arg_accumulator = [] + self._argspec = [] + self._finalized = False + + def _consume_args(self): + if self._arg_accumulator: + self._argspec.append(gast.Tuple(elts=self._arg_accumulator, ctx=None)) + self._arg_accumulator = [] + + def add_arg(self, a): + self._arg_accumulator.append(a) + + def add_stararg(self, a): + self._consume_args() + self._argspec.append( + gast.Call(gast.Name('tuple', gast.Load(), None), [a], ())) + + def finalize(self): + self._consume_args() + self._finalized = True + + def to_ast(self): + assert self._finalized + if self._argspec: + result = self._argspec[0] + for i in range(1, len(self._argspec)): + result = gast.BinOp(result, gast.Add(), self._argspec[i]) + return result + return gast.Tuple([], None) + + class CallTreeTransformer(converter.Base): """Transforms the call tree by renaming transformed symbols.""" @@ -98,6 +137,32 @@ class CallTreeTransformer(converter.Base): node.body = self.visit_block(node.body) return node + def _args_to_tuple(self, node): + """Ties together all positional and *arg arguments in a single tuple.""" + # TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better? + # For example for + # f(a, b, *args) + # instead of writing: + # (a, b) + args + # just write this? + # tuple(a, b, *args) + builder = _ArgTemplateBuilder() + for a in node.args: + if isinstance(a, gast.Starred): + builder.add_stararg(a.value) + else: + builder.add_arg(a) + builder.finalize() + return builder.to_ast() + + def _kwargs_to_dict(self, node): + """Ties together all keyword and **kwarg arguments in a single dict.""" + if node.keywords: + return gast.Call( + gast.Name('dict', gast.Load(), None), args=(), keywords=node.keywords) + else: + return parser.parse_expression('None') + def visit_Call(self, node): full_name = str(anno.getanno(node.func, anno.Basic.QN, default='')) function_context_name = self.state[_Function].context_name @@ -136,51 +201,14 @@ class CallTreeTransformer(converter.Base): not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): return node - func = node.func - - starred_arg = None - normal_args = [] - for a in node.args: - if isinstance(a, gast.Starred): - assert starred_arg is None, 'Multiple *args should be impossible.' - starred_arg = a - else: - normal_args.append(a) - if starred_arg is None: - args = templates.replace_as_expression('(args,)', args=normal_args) - else: - args = templates.replace_as_expression( - '(args,) + tuple(stararg)', - stararg=starred_arg.value, - args=normal_args) - - kwargs_arg = None - normal_keywords = [] - for k in node.keywords: - if k.arg is None: - assert kwargs_arg is None, 'Multiple **kwargs should be impossible.' - kwargs_arg = k - else: - normal_keywords.append(k) - if kwargs_arg is None: - if not normal_keywords: - kwargs = parser.parse_expression('None') - else: - kwargs = ast_util.keywords_to_dict(normal_keywords) - else: - kwargs = templates.replace_as_expression( - 'dict(kwargs, **keywords)', - kwargs=kwargs_arg.value, - keywords=ast_util.keywords_to_dict(normal_keywords)) - template = """ ag__.converted_call(func, args, kwargs, function_ctx) """ new_call = templates.replace_as_expression( template, - func=func, - args=args, - kwargs=kwargs, + func=node.func, + args=self._args_to_tuple(node), + kwargs=self._kwargs_to_dict(node), function_ctx=function_context_name) return new_call diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py index 6336d380b10..2140224847a 100644 --- a/tensorflow/python/autograph/converters/call_trees_test.py +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -1,3 +1,4 @@ +# python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -101,7 +102,7 @@ class CallTreesTest(converter_testing.TestCase): }), 12) self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})]) - def test_function_with_kwargs_starargs_only(self): + def test_function_with_starargs_only(self): def f(*args): return sum(args) @@ -115,6 +116,22 @@ class CallTreesTest(converter_testing.TestCase): self.assertEqual(result.test_fn(), 4321) self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)]) + # TODO(b/142586827): Enable this test. + # def test_function_with_starargs_mixed(self): + # + # def f(a, b, c, d): + # return a * 1000 + b * 100 + c * 10 + d + # + # def test_fn(): + # args1 = (1,) + # args2 = [3] + # return f(*args1, 2, *args2, 4) + # + # with self.converted(test_fn, (function_scopes, call_trees), + # {'f': f}) as result: + # self.assertEqual(result.test_fn(), 1234) + # self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)]) + def test_function_with_kwargs_keywords(self): def test_fn(f, a, b, **kwargs): @@ -125,6 +142,23 @@ class CallTreesTest(converter_testing.TestCase): result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12) self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})]) + # TODO(b/142586827): Enable this test. + # def test_function_with_multiple_kwargs(self): + # + # def test_fn(f, a, b, c, kwargs1, kwargs2): + # return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5 + # + # with self.converted(test_fn, (function_scopes, call_trees), {}) as result: + # self.assertEqual( + # result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, + # {'e': 5}), 12) + # self.assertListEqual(self.dynamic_calls, [((1,), { + # 'b': 2, + # 'c': 3, + # 'd': 4, + # 'e': 5 + # })]) + def test_debugger_set_trace(self): tracking_list = [] diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD index 616cd8a18e9..234695c880d 100644 --- a/tensorflow/python/autograph/impl/BUILD +++ b/tensorflow/python/autograph/impl/BUILD @@ -61,8 +61,8 @@ py_test( deps = [ ":impl", "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/utils", - "//third_party/py/numpy", + "//tensorflow/python:constant_op", + "//tensorflow/python/autograph/core", ], ) diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 5754249bc91..242c1edc5c3 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -374,7 +374,7 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): For internal use only. Note: The argument list is optimized for readability of generated code, which - may looks something like this: + may look like this: ag__.converted_call(f, (arg1, arg2), None, fscope) ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope) diff --git a/tensorflow/python/autograph/impl/api_py3_test.py b/tensorflow/python/autograph/impl/api_py3_test.py index 9f8a4b3f31d..df6544928bf 100644 --- a/tensorflow/python/autograph/impl/api_py3_test.py +++ b/tensorflow/python/autograph/impl/api_py3_test.py @@ -56,7 +56,7 @@ class ApiTest(test.TestCase): def no_arg(self, x): return super().plus_three(x) - tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE) + tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE) self.assertEqual(5, tc.no_arg(2))