Handle cases of multiple *args and **kwargs, which are legal in Python3.

PiperOrigin-RevId: 274371522
This commit is contained in:
Dan Moldovan 2019-10-12 14:08:58 -07:00 committed by TensorFlower Gardener
parent deadb15945
commit 8ba59743a4
6 changed files with 117 additions and 49 deletions

View File

@ -42,6 +42,7 @@ py_library(
"//tensorflow/python/autograph/lang",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
],
)
@ -82,13 +83,18 @@ py_test(
py_test(
name = "call_trees_test",
srcs = ["call_trees_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"no_windows",
"nopip",
],
deps = [
":converters",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/autograph/impl",
],
)

View File

@ -28,7 +28,6 @@ import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.utils import ag_logging
@ -48,6 +47,46 @@ class _Function(object):
set_trace_warned = False
class _ArgTemplateBuilder(object):
"""Constructs a tuple representing the positional arguments in a call.
Example (yes, it's legal Python 3):
f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d)
"""
def __init__(self):
self._arg_accumulator = []
self._argspec = []
self._finalized = False
def _consume_args(self):
if self._arg_accumulator:
self._argspec.append(gast.Tuple(elts=self._arg_accumulator, ctx=None))
self._arg_accumulator = []
def add_arg(self, a):
self._arg_accumulator.append(a)
def add_stararg(self, a):
self._consume_args()
self._argspec.append(
gast.Call(gast.Name('tuple', gast.Load(), None), [a], ()))
def finalize(self):
self._consume_args()
self._finalized = True
def to_ast(self):
assert self._finalized
if self._argspec:
result = self._argspec[0]
for i in range(1, len(self._argspec)):
result = gast.BinOp(result, gast.Add(), self._argspec[i])
return result
return gast.Tuple([], None)
class CallTreeTransformer(converter.Base):
"""Transforms the call tree by renaming transformed symbols."""
@ -98,6 +137,32 @@ class CallTreeTransformer(converter.Base):
node.body = self.visit_block(node.body)
return node
def _args_to_tuple(self, node):
"""Ties together all positional and *arg arguments in a single tuple."""
# TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
# For example for
# f(a, b, *args)
# instead of writing:
# (a, b) + args
# just write this?
# tuple(a, b, *args)
builder = _ArgTemplateBuilder()
for a in node.args:
if isinstance(a, gast.Starred):
builder.add_stararg(a.value)
else:
builder.add_arg(a)
builder.finalize()
return builder.to_ast()
def _kwargs_to_dict(self, node):
"""Ties together all keyword and **kwarg arguments in a single dict."""
if node.keywords:
return gast.Call(
gast.Name('dict', gast.Load(), None), args=(), keywords=node.keywords)
else:
return parser.parse_expression('None')
def visit_Call(self, node):
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
function_context_name = self.state[_Function].context_name
@ -136,51 +201,14 @@ class CallTreeTransformer(converter.Base):
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
return node
func = node.func
starred_arg = None
normal_args = []
for a in node.args:
if isinstance(a, gast.Starred):
assert starred_arg is None, 'Multiple *args should be impossible.'
starred_arg = a
else:
normal_args.append(a)
if starred_arg is None:
args = templates.replace_as_expression('(args,)', args=normal_args)
else:
args = templates.replace_as_expression(
'(args,) + tuple(stararg)',
stararg=starred_arg.value,
args=normal_args)
kwargs_arg = None
normal_keywords = []
for k in node.keywords:
if k.arg is None:
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
kwargs_arg = k
else:
normal_keywords.append(k)
if kwargs_arg is None:
if not normal_keywords:
kwargs = parser.parse_expression('None')
else:
kwargs = ast_util.keywords_to_dict(normal_keywords)
else:
kwargs = templates.replace_as_expression(
'dict(kwargs, **keywords)',
kwargs=kwargs_arg.value,
keywords=ast_util.keywords_to_dict(normal_keywords))
template = """
ag__.converted_call(func, args, kwargs, function_ctx)
"""
new_call = templates.replace_as_expression(
template,
func=func,
args=args,
kwargs=kwargs,
func=node.func,
args=self._args_to_tuple(node),
kwargs=self._kwargs_to_dict(node),
function_ctx=function_context_name)
return new_call

View File

@ -1,3 +1,4 @@
# python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -101,7 +102,7 @@ class CallTreesTest(converter_testing.TestCase):
}), 12)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
def test_function_with_kwargs_starargs_only(self):
def test_function_with_starargs_only(self):
def f(*args):
return sum(args)
@ -115,6 +116,22 @@ class CallTreesTest(converter_testing.TestCase):
self.assertEqual(result.test_fn(), 4321)
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
# TODO(b/142586827): Enable this test.
# def test_function_with_starargs_mixed(self):
#
# def f(a, b, c, d):
# return a * 1000 + b * 100 + c * 10 + d
#
# def test_fn():
# args1 = (1,)
# args2 = [3]
# return f(*args1, 2, *args2, 4)
#
# with self.converted(test_fn, (function_scopes, call_trees),
# {'f': f}) as result:
# self.assertEqual(result.test_fn(), 1234)
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
def test_function_with_kwargs_keywords(self):
def test_fn(f, a, b, **kwargs):
@ -125,6 +142,23 @@ class CallTreesTest(converter_testing.TestCase):
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
# TODO(b/142586827): Enable this test.
# def test_function_with_multiple_kwargs(self):
#
# def test_fn(f, a, b, c, kwargs1, kwargs2):
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
#
# with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
# self.assertEqual(
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
# {'e': 5}), 12)
# self.assertListEqual(self.dynamic_calls, [((1,), {
# 'b': 2,
# 'c': 3,
# 'd': 4,
# 'e': 5
# })])
def test_debugger_set_trace(self):
tracking_list = []

View File

@ -61,8 +61,8 @@ py_test(
deps = [
":impl",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/utils",
"//third_party/py/numpy",
"//tensorflow/python:constant_op",
"//tensorflow/python/autograph/core",
],
)

View File

@ -374,7 +374,7 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
For internal use only.
Note: The argument list is optimized for readability of generated code, which
may looks something like this:
may look like this:
ag__.converted_call(f, (arg1, arg2), None, fscope)
ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)

View File

@ -56,7 +56,7 @@ class ApiTest(test.TestCase):
def no_arg(self, x):
return super().plus_three(x)
tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE)
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(5, tc.no_arg(2))