Handle cases of multiple *args and **kwargs, which are legal in Python3.
PiperOrigin-RevId: 274371522
This commit is contained in:
parent
deadb15945
commit
8ba59743a4
@ -42,6 +42,7 @@ py_library(
|
||||
"//tensorflow/python/autograph/lang",
|
||||
"//tensorflow/python/autograph/pyct",
|
||||
"//tensorflow/python/autograph/pyct/static_analysis",
|
||||
"//tensorflow/python/autograph/utils",
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
@ -82,13 +83,18 @@ py_test(
|
||||
py_test(
|
||||
name = "call_trees_test",
|
||||
srcs = ["call_trees_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_windows"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"nopip",
|
||||
],
|
||||
deps = [
|
||||
":converters",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/autograph/core:test_lib",
|
||||
"//tensorflow/python/autograph/impl",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,6 @@ import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
@ -48,6 +47,46 @@ class _Function(object):
|
||||
set_trace_warned = False
|
||||
|
||||
|
||||
class _ArgTemplateBuilder(object):
|
||||
"""Constructs a tuple representing the positional arguments in a call.
|
||||
|
||||
Example (yes, it's legal Python 3):
|
||||
|
||||
f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._arg_accumulator = []
|
||||
self._argspec = []
|
||||
self._finalized = False
|
||||
|
||||
def _consume_args(self):
|
||||
if self._arg_accumulator:
|
||||
self._argspec.append(gast.Tuple(elts=self._arg_accumulator, ctx=None))
|
||||
self._arg_accumulator = []
|
||||
|
||||
def add_arg(self, a):
|
||||
self._arg_accumulator.append(a)
|
||||
|
||||
def add_stararg(self, a):
|
||||
self._consume_args()
|
||||
self._argspec.append(
|
||||
gast.Call(gast.Name('tuple', gast.Load(), None), [a], ()))
|
||||
|
||||
def finalize(self):
|
||||
self._consume_args()
|
||||
self._finalized = True
|
||||
|
||||
def to_ast(self):
|
||||
assert self._finalized
|
||||
if self._argspec:
|
||||
result = self._argspec[0]
|
||||
for i in range(1, len(self._argspec)):
|
||||
result = gast.BinOp(result, gast.Add(), self._argspec[i])
|
||||
return result
|
||||
return gast.Tuple([], None)
|
||||
|
||||
|
||||
class CallTreeTransformer(converter.Base):
|
||||
"""Transforms the call tree by renaming transformed symbols."""
|
||||
|
||||
@ -98,6 +137,32 @@ class CallTreeTransformer(converter.Base):
|
||||
node.body = self.visit_block(node.body)
|
||||
return node
|
||||
|
||||
def _args_to_tuple(self, node):
|
||||
"""Ties together all positional and *arg arguments in a single tuple."""
|
||||
# TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
|
||||
# For example for
|
||||
# f(a, b, *args)
|
||||
# instead of writing:
|
||||
# (a, b) + args
|
||||
# just write this?
|
||||
# tuple(a, b, *args)
|
||||
builder = _ArgTemplateBuilder()
|
||||
for a in node.args:
|
||||
if isinstance(a, gast.Starred):
|
||||
builder.add_stararg(a.value)
|
||||
else:
|
||||
builder.add_arg(a)
|
||||
builder.finalize()
|
||||
return builder.to_ast()
|
||||
|
||||
def _kwargs_to_dict(self, node):
|
||||
"""Ties together all keyword and **kwarg arguments in a single dict."""
|
||||
if node.keywords:
|
||||
return gast.Call(
|
||||
gast.Name('dict', gast.Load(), None), args=(), keywords=node.keywords)
|
||||
else:
|
||||
return parser.parse_expression('None')
|
||||
|
||||
def visit_Call(self, node):
|
||||
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
|
||||
function_context_name = self.state[_Function].context_name
|
||||
@ -136,51 +201,14 @@ class CallTreeTransformer(converter.Base):
|
||||
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||
return node
|
||||
|
||||
func = node.func
|
||||
|
||||
starred_arg = None
|
||||
normal_args = []
|
||||
for a in node.args:
|
||||
if isinstance(a, gast.Starred):
|
||||
assert starred_arg is None, 'Multiple *args should be impossible.'
|
||||
starred_arg = a
|
||||
else:
|
||||
normal_args.append(a)
|
||||
if starred_arg is None:
|
||||
args = templates.replace_as_expression('(args,)', args=normal_args)
|
||||
else:
|
||||
args = templates.replace_as_expression(
|
||||
'(args,) + tuple(stararg)',
|
||||
stararg=starred_arg.value,
|
||||
args=normal_args)
|
||||
|
||||
kwargs_arg = None
|
||||
normal_keywords = []
|
||||
for k in node.keywords:
|
||||
if k.arg is None:
|
||||
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
|
||||
kwargs_arg = k
|
||||
else:
|
||||
normal_keywords.append(k)
|
||||
if kwargs_arg is None:
|
||||
if not normal_keywords:
|
||||
kwargs = parser.parse_expression('None')
|
||||
else:
|
||||
kwargs = ast_util.keywords_to_dict(normal_keywords)
|
||||
else:
|
||||
kwargs = templates.replace_as_expression(
|
||||
'dict(kwargs, **keywords)',
|
||||
kwargs=kwargs_arg.value,
|
||||
keywords=ast_util.keywords_to_dict(normal_keywords))
|
||||
|
||||
template = """
|
||||
ag__.converted_call(func, args, kwargs, function_ctx)
|
||||
"""
|
||||
new_call = templates.replace_as_expression(
|
||||
template,
|
||||
func=func,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
func=node.func,
|
||||
args=self._args_to_tuple(node),
|
||||
kwargs=self._kwargs_to_dict(node),
|
||||
function_ctx=function_context_name)
|
||||
|
||||
return new_call
|
||||
|
@ -1,3 +1,4 @@
|
||||
# python3
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -101,7 +102,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
}), 12)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
||||
|
||||
def test_function_with_kwargs_starargs_only(self):
|
||||
def test_function_with_starargs_only(self):
|
||||
|
||||
def f(*args):
|
||||
return sum(args)
|
||||
@ -115,6 +116,22 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
self.assertEqual(result.test_fn(), 4321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
||||
|
||||
# TODO(b/142586827): Enable this test.
|
||||
# def test_function_with_starargs_mixed(self):
|
||||
#
|
||||
# def f(a, b, c, d):
|
||||
# return a * 1000 + b * 100 + c * 10 + d
|
||||
#
|
||||
# def test_fn():
|
||||
# args1 = (1,)
|
||||
# args2 = [3]
|
||||
# return f(*args1, 2, *args2, 4)
|
||||
#
|
||||
# with self.converted(test_fn, (function_scopes, call_trees),
|
||||
# {'f': f}) as result:
|
||||
# self.assertEqual(result.test_fn(), 1234)
|
||||
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
|
||||
|
||||
def test_function_with_kwargs_keywords(self):
|
||||
|
||||
def test_fn(f, a, b, **kwargs):
|
||||
@ -125,6 +142,23 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
||||
|
||||
# TODO(b/142586827): Enable this test.
|
||||
# def test_function_with_multiple_kwargs(self):
|
||||
#
|
||||
# def test_fn(f, a, b, c, kwargs1, kwargs2):
|
||||
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
|
||||
#
|
||||
# with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
# self.assertEqual(
|
||||
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
|
||||
# {'e': 5}), 12)
|
||||
# self.assertListEqual(self.dynamic_calls, [((1,), {
|
||||
# 'b': 2,
|
||||
# 'c': 3,
|
||||
# 'd': 4,
|
||||
# 'e': 5
|
||||
# })])
|
||||
|
||||
def test_debugger_set_trace(self):
|
||||
|
||||
tracking_list = []
|
||||
|
@ -61,8 +61,8 @@ py_test(
|
||||
deps = [
|
||||
":impl",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/autograph/utils",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python/autograph/core",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -374,7 +374,7 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
|
||||
For internal use only.
|
||||
|
||||
Note: The argument list is optimized for readability of generated code, which
|
||||
may looks something like this:
|
||||
may look like this:
|
||||
|
||||
ag__.converted_call(f, (arg1, arg2), None, fscope)
|
||||
ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
|
||||
|
@ -56,7 +56,7 @@ class ApiTest(test.TestCase):
|
||||
def no_arg(self, x):
|
||||
return super().plus_three(x)
|
||||
|
||||
tc = api.converted_call(TestSubclass, (), {}, options=DEFAULT_RECURSIVE)
|
||||
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
|
||||
|
||||
self.assertEqual(5, tc.no_arg(2))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user