Handle cases of multiple *args and **kwargs, which are legal in Python3.
PiperOrigin-RevId: 274371522
This commit is contained in:
parent
deadb15945
commit
8ba59743a4
@ -42,6 +42,7 @@ py_library(
|
|||||||
"//tensorflow/python/autograph/lang",
|
"//tensorflow/python/autograph/lang",
|
||||||
"//tensorflow/python/autograph/pyct",
|
"//tensorflow/python/autograph/pyct",
|
||||||
"//tensorflow/python/autograph/pyct/static_analysis",
|
"//tensorflow/python/autograph/pyct/static_analysis",
|
||||||
|
"//tensorflow/python/autograph/utils",
|
||||||
"@gast_archive//:gast",
|
"@gast_archive//:gast",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -82,13 +83,18 @@ py_test(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "call_trees_test",
|
name = "call_trees_test",
|
||||||
srcs = ["call_trees_test.py"],
|
srcs = ["call_trees_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
python_version = "PY3",
|
||||||
tags = ["no_windows"],
|
srcs_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_oss_py2",
|
||||||
|
"no_pip",
|
||||||
|
"no_windows",
|
||||||
|
"nopip",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":converters",
|
":converters",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/autograph/core:test_lib",
|
"//tensorflow/python/autograph/core:test_lib",
|
||||||
"//tensorflow/python/autograph/impl",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +28,6 @@ import gast
|
|||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import ast_util
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import templates
|
from tensorflow.python.autograph.pyct import templates
|
||||||
from tensorflow.python.autograph.utils import ag_logging
|
from tensorflow.python.autograph.utils import ag_logging
|
||||||
@ -48,6 +47,46 @@ class _Function(object):
|
|||||||
set_trace_warned = False
|
set_trace_warned = False
|
||||||
|
|
||||||
|
|
||||||
|
class _ArgTemplateBuilder(object):
|
||||||
|
"""Constructs a tuple representing the positional arguments in a call.
|
||||||
|
|
||||||
|
Example (yes, it's legal Python 3):
|
||||||
|
|
||||||
|
f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._arg_accumulator = []
|
||||||
|
self._argspec = []
|
||||||
|
self._finalized = False
|
||||||
|
|
||||||
|
def _consume_args(self):
|
||||||
|
if self._arg_accumulator:
|
||||||
|
self._argspec.append(gast.Tuple(elts=self._arg_accumulator, ctx=None))
|
||||||
|
self._arg_accumulator = []
|
||||||
|
|
||||||
|
def add_arg(self, a):
|
||||||
|
self._arg_accumulator.append(a)
|
||||||
|
|
||||||
|
def add_stararg(self, a):
|
||||||
|
self._consume_args()
|
||||||
|
self._argspec.append(
|
||||||
|
gast.Call(gast.Name('tuple', gast.Load(), None), [a], ()))
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
self._consume_args()
|
||||||
|
self._finalized = True
|
||||||
|
|
||||||
|
def to_ast(self):
|
||||||
|
assert self._finalized
|
||||||
|
if self._argspec:
|
||||||
|
result = self._argspec[0]
|
||||||
|
for i in range(1, len(self._argspec)):
|
||||||
|
result = gast.BinOp(result, gast.Add(), self._argspec[i])
|
||||||
|
return result
|
||||||
|
return gast.Tuple([], None)
|
||||||
|
|
||||||
|
|
||||||
class CallTreeTransformer(converter.Base):
|
class CallTreeTransformer(converter.Base):
|
||||||
"""Transforms the call tree by renaming transformed symbols."""
|
"""Transforms the call tree by renaming transformed symbols."""
|
||||||
|
|
||||||
@ -98,6 +137,32 @@ class CallTreeTransformer(converter.Base):
|
|||||||
node.body = self.visit_block(node.body)
|
node.body = self.visit_block(node.body)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def _args_to_tuple(self, node):
|
||||||
|
"""Ties together all positional and *arg arguments in a single tuple."""
|
||||||
|
# TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
|
||||||
|
# For example for
|
||||||
|
# f(a, b, *args)
|
||||||
|
# instead of writing:
|
||||||
|
# (a, b) + args
|
||||||
|
# just write this?
|
||||||
|
# tuple(a, b, *args)
|
||||||
|
builder = _ArgTemplateBuilder()
|
||||||
|
for a in node.args:
|
||||||
|
if isinstance(a, gast.Starred):
|
||||||
|
builder.add_stararg(a.value)
|
||||||
|
else:
|
||||||
|
builder.add_arg(a)
|
||||||
|
builder.finalize()
|
||||||
|
return builder.to_ast()
|
||||||
|
|
||||||
|
def _kwargs_to_dict(self, node):
|
||||||
|
"""Ties together all keyword and **kwarg arguments in a single dict."""
|
||||||
|
if node.keywords:
|
||||||
|
return gast.Call(
|
||||||
|
gast.Name('dict', gast.Load(), None), args=(), keywords=node.keywords)
|
||||||
|
else:
|
||||||
|
return parser.parse_expression('None')
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
|
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
|
||||||
function_context_name = self.state[_Function].context_name
|
function_context_name = self.state[_Function].context_name
|
||||||
@ -136,51 +201,14 @@ class CallTreeTransformer(converter.Base):
|
|||||||
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
func = node.func
|
|
||||||
|
|
||||||
starred_arg = None
|
|
||||||
normal_args = []
|
|
||||||
for a in node.args:
|
|
||||||
if isinstance(a, gast.Starred):
|
|
||||||
assert starred_arg is None, 'Multiple *args should be impossible.'
|
|
||||||
starred_arg = a
|
|
||||||
else:
|
|
||||||
normal_args.append(a)
|
|
||||||
if starred_arg is None:
|
|
||||||
args = templates.replace_as_expression('(args,)', args=normal_args)
|
|
||||||
else:
|
|
||||||
args = templates.replace_as_expression(
|
|
||||||
'(args,) + tuple(stararg)',
|
|
||||||
stararg=starred_arg.value,
|
|
||||||
args=normal_args)
|
|
||||||
|
|
||||||
kwargs_arg = None
|
|
||||||
normal_keywords = []
|
|
||||||
for k in node.keywords:
|
|
||||||
if k.arg is None:
|
|
||||||
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
|
|
||||||
kwargs_arg = k
|
|
||||||
else:
|
|
||||||
normal_keywords.append(k)
|
|
||||||
if kwargs_arg is None:
|
|
||||||
if not normal_keywords:
|
|
||||||
kwargs = parser.parse_expression('None')
|
|
||||||
else:
|
|
||||||
kwargs = ast_util.keywords_to_dict(normal_keywords)
|
|
||||||
else:
|
|
||||||
kwargs = templates.replace_as_expression(
|
|
||||||
'dict(kwargs, **keywords)',
|
|
||||||
kwargs=kwargs_arg.value,
|
|
||||||
keywords=ast_util.keywords_to_dict(normal_keywords))
|
|
||||||
|
|
||||||
template = """
|
template = """
|
||||||
ag__.converted_call(func, args, kwargs, function_ctx)
|
ag__.converted_call(func, args, kwargs, function_ctx)
|
||||||
"""
|
"""
|
||||||
new_call = templates.replace_as_expression(
|
new_call = templates.replace_as_expression(
|
||||||
template,
|
template,
|
||||||
func=func,
|
func=node.func,
|
||||||
args=args,
|
args=self._args_to_tuple(node),
|
||||||
kwargs=kwargs,
|
kwargs=self._kwargs_to_dict(node),
|
||||||
function_ctx=function_context_name)
|
function_ctx=function_context_name)
|
||||||
|
|
||||||
return new_call
|
return new_call
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# python3
|
||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -101,7 +102,7 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
}), 12)
|
}), 12)
|
||||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
||||||
|
|
||||||
def test_function_with_kwargs_starargs_only(self):
|
def test_function_with_starargs_only(self):
|
||||||
|
|
||||||
def f(*args):
|
def f(*args):
|
||||||
return sum(args)
|
return sum(args)
|
||||||
@ -115,6 +116,22 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
self.assertEqual(result.test_fn(), 4321)
|
self.assertEqual(result.test_fn(), 4321)
|
||||||
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
||||||
|
|
||||||
|
# TODO(b/142586827): Enable this test.
|
||||||
|
# def test_function_with_starargs_mixed(self):
|
||||||
|
#
|
||||||
|
# def f(a, b, c, d):
|
||||||
|
# return a * 1000 + b * 100 + c * 10 + d
|
||||||
|
#
|
||||||
|
# def test_fn():
|
||||||
|
# args1 = (1,)
|
||||||
|
# args2 = [3]
|
||||||
|
# return f(*args1, 2, *args2, 4)
|
||||||
|
#
|
||||||
|
# with self.converted(test_fn, (function_scopes, call_trees),
|
||||||
|
# {'f': f}) as result:
|
||||||
|
# self.assertEqual(result.test_fn(), 1234)
|
||||||
|
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
|
||||||
|
|
||||||
def test_function_with_kwargs_keywords(self):
|
def test_function_with_kwargs_keywords(self):
|
||||||
|
|
||||||
def test_fn(f, a, b, **kwargs):
|
def test_fn(f, a, b, **kwargs):
|
||||||
@ -125,6 +142,23 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
||||||
|
|
||||||
|
# TODO(b/142586827): Enable this test.
|
||||||
|
# def test_function_with_multiple_kwargs(self):
|
||||||
|
#
|
||||||
|
# def test_fn(f, a, b, c, kwargs1, kwargs2):
|
||||||
|
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
|
||||||
|
#
|
||||||
|
# with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||||
|
# self.assertEqual(
|
||||||
|
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
|
||||||
|
# {'e': 5}), 12)
|
||||||
|
# self.assertListEqual(self.dynamic_calls, [((1,), {
|
||||||
|
# 'b': 2,
|
||||||
|
# 'c': 3,
|
||||||
|
# 'd': 4,
|
||||||
|
# 'e': 5
|
||||||
|
# })])
|
||||||
|
|
||||||
def test_debugger_set_trace(self):
|
def test_debugger_set_trace(self):
|
||||||
|
|
||||||
tracking_list = []
|
tracking_list = []
|
||||||
|
@ -61,8 +61,8 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":impl",
|
":impl",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/autograph/utils",
|
"//tensorflow/python:constant_op",
|
||||||
"//third_party/py/numpy",
|
"//tensorflow/python/autograph/core",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -374,7 +374,7 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
|
|||||||
For internal use only.
|
For internal use only.
|
||||||
|
|
||||||
Note: The argument list is optimized for readability of generated code, which
|
Note: The argument list is optimized for readability of generated code, which
|
||||||
may looks something like this:
|
may look like this:
|
||||||
|
|
||||||
ag__.converted_call(f, (arg1, arg2), None, fscope)
|
ag__.converted_call(f, (arg1, arg2), None, fscope)
|
||||||
ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
|
ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
|
||||||
|
@ -56,7 +56,7 @@ class ApiTest(test.TestCase):
|
|||||||
def no_arg(self, x):
|
def no_arg(self, x):
|
||||||
return super().plus_three(x)
|
return super().plus_three(x)
|
||||||
|
|
||||||
tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE)
|
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
|
||||||
|
|
||||||
self.assertEqual(5, tc.no_arg(2))
|
self.assertEqual(5, tc.no_arg(2))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user