Correctly track arguments and their defaults in activity analysis. Do so consistently between functions and lambdas. Fix unit tests.
PiperOrigin-RevId: 307664175 Change-Id: Iaf062745e1f6040bf32897bcf89266c7202c235b
This commit is contained in:
parent
e225b8c8ca
commit
81734a896e
tensorflow/python/autograph/pyct/static_analysis
@ -488,9 +488,6 @@ class ActivityAnalyzer(transformer.Base):
|
||||
def visit_GeneratorExp(self, node):
|
||||
return self._process_comprehension(node)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
return self._process_statement(node)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
with self.state[_FunctionOrClass] as fn:
|
||||
fn.node = node
|
||||
@ -510,6 +507,24 @@ class ActivityAnalyzer(transformer.Base):
|
||||
self._exit_scope()
|
||||
return node
|
||||
|
||||
def _visit_node_list(self, nodes):
|
||||
return [(None if n is None else self.visit(n)) for n in nodes]
|
||||
|
||||
def _visit_arg_defaults(self, node):
|
||||
node.args.kw_defaults = self._visit_node_list(node.args.kw_defaults)
|
||||
node.args.defaults = self._visit_node_list(node.args.defaults)
|
||||
return node
|
||||
|
||||
def _visit_arg_declarations(self, node):
|
||||
node.args.posonlyargs = self._visit_node_list(node.args.posonlyargs)
|
||||
node.args.args = self._visit_node_list(node.args.args)
|
||||
if node.args.vararg is not None:
|
||||
node.args.vararg = self.visit(node.args.vararg)
|
||||
node.args.kwonlyargs = self._visit_node_list(node.args.kwonlyargs)
|
||||
if node.args.kwarg is not None:
|
||||
node.args.kwarg = self.visit(node.args.kwarg)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
with self.state[_FunctionOrClass] as fn:
|
||||
fn.node = node
|
||||
@ -517,6 +532,11 @@ class ActivityAnalyzer(transformer.Base):
|
||||
# of its name, along with the usage of any decorator accompanying it.
|
||||
self._enter_scope(False)
|
||||
node.decorator_list = self.visit_block(node.decorator_list)
|
||||
|
||||
# Arg defaults affect the defining context - they are being evaluated
|
||||
# at definition.
|
||||
node = self._visit_arg_defaults(node)
|
||||
|
||||
function_name = qual_names.QN(node.name)
|
||||
self.scope.modified.add(function_name)
|
||||
self.scope.bound.add(function_name)
|
||||
@ -524,7 +544,15 @@ class ActivityAnalyzer(transformer.Base):
|
||||
|
||||
# A separate Scope tracks the actual function definition.
|
||||
self._enter_scope(True)
|
||||
node.args = self.visit(node.args)
|
||||
|
||||
# Keep a separate scope for the arguments node, which is used in the CFG.
|
||||
self._enter_scope(False)
|
||||
|
||||
# Arg declarations only affect the function itself, and have no effect
|
||||
# in the defining context whatsoever.
|
||||
node = self._visit_arg_declarations(node)
|
||||
|
||||
self._exit_and_record_scope(node.args)
|
||||
|
||||
# Track the body separately. This is for compatibility reasons, it may not
|
||||
# be strictly needed.
|
||||
@ -532,16 +560,35 @@ class ActivityAnalyzer(transformer.Base):
|
||||
node.body = self.visit_block(node.body)
|
||||
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
|
||||
|
||||
self._exit_scope()
|
||||
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
return node
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
# Lambda nodes are treated in roughly the same way as FunctionDef nodes.
|
||||
with self.state[_FunctionOrClass] as fn:
|
||||
fn.node = node
|
||||
self._enter_scope(True)
|
||||
node = self.generic_visit(node)
|
||||
# The Lambda node itself has a Scope object that tracks the creation
|
||||
# of its name, along with the usage of any decorator accompanying it.
|
||||
self._enter_scope(False)
|
||||
node = self._visit_arg_defaults(node)
|
||||
self._exit_and_record_scope(node)
|
||||
|
||||
# A separate Scope tracks the actual function definition.
|
||||
self._enter_scope(True)
|
||||
|
||||
# Keep a separate scope for the arguments node, which is used in the CFG.
|
||||
self._enter_scope(False)
|
||||
node = self._visit_arg_declarations(node)
|
||||
self._exit_and_record_scope(node.args)
|
||||
|
||||
# Track the body separately. This is for compatibility reasons, it may not
|
||||
# be strictly needed.
|
||||
# TODO(mdan): Do remove it, it's confusing.
|
||||
self._enter_scope(False)
|
||||
node.body = self.visit(node.body)
|
||||
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
|
||||
|
||||
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
return node
|
||||
|
||||
def visit_With(self, node):
|
||||
|
@ -373,17 +373,48 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||
y = x * x
|
||||
return y
|
||||
|
||||
b = a
|
||||
for i in a:
|
||||
c = b
|
||||
b -= f(i)
|
||||
return b, c
|
||||
return f(a)
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'f'), ('f',))
|
||||
|
||||
fn_def_node = node.body[0]
|
||||
|
||||
scope = anno.getanno(fn_def_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, (), ('f'))
|
||||
|
||||
scope = anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('x', 'y'), ('y',))
|
||||
|
||||
scope = anno.getanno(fn_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('x', 'y'), ('y',))
|
||||
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
|
||||
|
||||
def test_nested_function_arg_defaults(self):
|
||||
|
||||
def test_fn(a):
|
||||
|
||||
def f(x=a):
|
||||
y = x * x
|
||||
return y
|
||||
|
||||
return f(a)
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_def_node = node.body[0]
|
||||
|
||||
self.assertScopeIs(
|
||||
anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',))
|
||||
anno.getanno(fn_def_node, anno.Static.SCOPE), ('a',), ('f',))
|
||||
|
||||
scope = anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('x', 'y'), ('y',))
|
||||
|
||||
scope = anno.getanno(fn_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('x', 'y'), ('y',))
|
||||
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
|
||||
|
||||
def test_constructor_attributes(self):
|
||||
|
||||
@ -482,64 +513,154 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||
self.assertScopeIs(
|
||||
anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ())
|
||||
|
||||
def test_params(self):
|
||||
|
||||
def test_fn(a, b): # pylint: disable=unused-argument
|
||||
return b
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('b',), ())
|
||||
self.assertScopeIs(body_scope.parent, ('b',), ())
|
||||
|
||||
args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
|
||||
|
||||
def test_lambda_captures_reads(self):
|
||||
def test_lambda(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
return lambda: a + b
|
||||
return lambda: (a + b)
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('a', 'b'), ())
|
||||
# Nothing local to the lambda is tracked.
|
||||
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
|
||||
|
||||
def test_lambda_params_are_isolated(self):
|
||||
fn_node = node
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
|
||||
lam_def_node = node.body[0].value
|
||||
|
||||
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, (), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
self.assertSymbolSetsAre((), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre((), scope.params.keys(), 'lambda params')
|
||||
|
||||
def test_lambda_params_args(self):
|
||||
|
||||
def test_fn(a, b): # pylint: disable=unused-argument
|
||||
return lambda a: a + b
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('b',), ())
|
||||
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
# Note: `a` in `a + b` is not "read" here because it's hidden by the `a`
|
||||
# argument.
|
||||
self.assertScopeIs(scope, ('b',), ())
|
||||
|
||||
lam_def_node = node.body[0].value
|
||||
|
||||
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, (), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
self.assertSymbolSetsAre(('a',), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(('a',), scope.params.keys(), 'lambda params')
|
||||
|
||||
def test_lambda_params_arg_defaults(self):
|
||||
|
||||
def test_fn(a, b, c): # pylint: disable=unused-argument
|
||||
return lambda b=c: a + b
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
# Note: `b` is not "read" here because it's hidden by the argument.
|
||||
self.assertScopeIs(scope, ('a', 'c'), ())
|
||||
|
||||
lam_def_node = node.body[0].value
|
||||
|
||||
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, ('c',), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b'), ())
|
||||
self.assertSymbolSetsAre(('b',), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(('b',), scope.params.keys(), 'lambda params')
|
||||
|
||||
def test_lambda_complex(self):
|
||||
|
||||
def test_fn(a, b, c, d): # pylint: disable=unused-argument
|
||||
a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
|
||||
def test_fn(a, b, c, d, e): # pylint: disable=unused-argument
|
||||
a = (lambda a, b, c=e: a + b + c)(d, 1, 2) + b
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
|
||||
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('d', 'b', 'e'), ('a',))
|
||||
|
||||
lam_def_node = node.body[0].value.left.func
|
||||
|
||||
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, ('e',), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
|
||||
self.assertSymbolSetsAre(('a', 'b', 'c'), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(
|
||||
('a', 'b', 'c'), scope.params.keys(), 'lambda params')
|
||||
|
||||
def test_lambda_nested(self):
|
||||
|
||||
def test_fn(a, b, c, d, e): # pylint: disable=unused-argument
|
||||
a = lambda a, b: d(lambda b: a + b + c) # pylint: disable=undefined-variable
|
||||
def test_fn(a, b, c, d, e, f): # pylint: disable=unused-argument
|
||||
a = lambda a, b: d(lambda b=f: a + b + c) # pylint: disable=undefined-variable
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('c', 'd'), ('a',))
|
||||
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('d', 'c', 'f'), ('a',))
|
||||
|
||||
outer_lam_def = node.body[0].value
|
||||
|
||||
scope = anno.getanno(outer_lam_def, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, (), ())
|
||||
|
||||
scope = anno.getanno(outer_lam_def, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('d', 'f', 'a', 'c'), ())
|
||||
|
||||
scope = anno.getanno(outer_lam_def, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('d', 'f', 'a', 'c'), ())
|
||||
self.assertSymbolSetsAre(('a', 'b'), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(outer_lam_def.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(('a', 'b'), scope.params.keys(), 'lambda params')
|
||||
|
||||
inner_lam_def = outer_lam_def.body.args[0]
|
||||
|
||||
scope = anno.getanno(inner_lam_def, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, ('f',), ())
|
||||
|
||||
scope = anno.getanno(inner_lam_def, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
|
||||
|
||||
scope = anno.getanno(inner_lam_def, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
|
||||
self.assertSymbolSetsAre(('b',), scope.bound, 'BOUND')
|
||||
|
||||
scope = anno.getanno(inner_lam_def.args, anno.Static.SCOPE)
|
||||
self.assertSymbolSetsAre(('b',), scope.params.keys(), 'lambda params')
|
||||
|
||||
def test_comprehension_targets_are_isolated(self):
|
||||
|
||||
|
@ -48,6 +48,9 @@ class NodeAnno(NoValue):
|
||||
ARGS_SCOPE = 'The scope for the argument list of a function call.'
|
||||
COND_SCOPE = 'The scope for the test node of a conditional statement.'
|
||||
ITERATE_SCOPE = 'The scope for the iterate assignment of a for loop.'
|
||||
ARGS_AND_BODY_SCOPE = (
|
||||
'The scope for the main body of a function or lambda, including its'
|
||||
' arguments.')
|
||||
BODY_SCOPE = (
|
||||
'The scope for the main body of a statement (True branch for if '
|
||||
'statements, main body for loops).')
|
||||
|
Loading…
Reference in New Issue
Block a user