Correctly track arguments and their defaults in activity analysis. Do so consistently between functions and lambdas. Fix unit tests.

PiperOrigin-RevId: 307664175
Change-Id: Iaf062745e1f6040bf32897bcf89266c7202c235b
This commit is contained in:
Dan Moldovan 2020-04-21 13:00:02 -07:00 committed by TensorFlower Gardener
parent e225b8c8ca
commit 81734a896e
3 changed files with 219 additions and 48 deletions
tensorflow/python/autograph/pyct/static_analysis

View File

@ -488,9 +488,6 @@ class ActivityAnalyzer(transformer.Base):
def visit_GeneratorExp(self, node):
return self._process_comprehension(node)
def visit_arguments(self, node):
return self._process_statement(node)
def visit_ClassDef(self, node):
with self.state[_FunctionOrClass] as fn:
fn.node = node
@ -510,6 +507,24 @@ class ActivityAnalyzer(transformer.Base):
self._exit_scope()
return node
def _visit_node_list(self, nodes):
return [(None if n is None else self.visit(n)) for n in nodes]
def _visit_arg_defaults(self, node):
node.args.kw_defaults = self._visit_node_list(node.args.kw_defaults)
node.args.defaults = self._visit_node_list(node.args.defaults)
return node
def _visit_arg_declarations(self, node):
node.args.posonlyargs = self._visit_node_list(node.args.posonlyargs)
node.args.args = self._visit_node_list(node.args.args)
if node.args.vararg is not None:
node.args.vararg = self.visit(node.args.vararg)
node.args.kwonlyargs = self._visit_node_list(node.args.kwonlyargs)
if node.args.kwarg is not None:
node.args.kwarg = self.visit(node.args.kwarg)
return node
def visit_FunctionDef(self, node):
with self.state[_FunctionOrClass] as fn:
fn.node = node
@ -517,6 +532,11 @@ class ActivityAnalyzer(transformer.Base):
# of its name, along with the usage of any decorator accompanying it.
self._enter_scope(False)
node.decorator_list = self.visit_block(node.decorator_list)
# Arg defaults affect the defining context - they are being evaluated
# at definition.
node = self._visit_arg_defaults(node)
function_name = qual_names.QN(node.name)
self.scope.modified.add(function_name)
self.scope.bound.add(function_name)
@ -524,7 +544,15 @@ class ActivityAnalyzer(transformer.Base):
# A separate Scope tracks the actual function definition.
self._enter_scope(True)
node.args = self.visit(node.args)
# Keep a separate scope for the arguments node, which is used in the CFG.
self._enter_scope(False)
# Arg declarations only affect the function itself, and have no effect
# in the defining context whatsoever.
node = self._visit_arg_declarations(node)
self._exit_and_record_scope(node.args)
# Track the body separately. This is for compatibility reasons, it may not
# be strictly needed.
@ -532,16 +560,35 @@ class ActivityAnalyzer(transformer.Base):
node.body = self.visit_block(node.body)
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
self._exit_scope()
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
return node
def visit_Lambda(self, node):
# Lambda nodes are treated in roughly the same way as FunctionDef nodes.
with self.state[_FunctionOrClass] as fn:
fn.node = node
self._enter_scope(True)
node = self.generic_visit(node)
# The Lambda node itself has a Scope object that tracks the creation
# of its name, along with the usage of any decorator accompanying it.
self._enter_scope(False)
node = self._visit_arg_defaults(node)
self._exit_and_record_scope(node)
# A separate Scope tracks the actual function definition.
self._enter_scope(True)
# Keep a separate scope for the arguments node, which is used in the CFG.
self._enter_scope(False)
node = self._visit_arg_declarations(node)
self._exit_and_record_scope(node.args)
# Track the body separately. This is for compatibility reasons, it may not
# be strictly needed.
# TODO(mdan): Do remove it, it's confusing.
self._enter_scope(False)
node.body = self.visit(node.body)
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
return node
def visit_With(self, node):

View File

@ -373,17 +373,48 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
y = x * x
return y
b = a
for i in a:
c = b
b -= f(i)
return b, c
return f(a)
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'f'), ('f',))
fn_def_node = node.body[0]
scope = anno.getanno(fn_def_node, anno.Static.SCOPE)
self.assertScopeIs(scope, (), ('f'))
scope = anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('x', 'y'), ('y',))
scope = anno.getanno(fn_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('x', 'y'), ('y',))
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
def test_nested_function_arg_defaults(self):
def test_fn(a):
def f(x=a):
y = x * x
return y
return f(a)
node, _ = self._parse_and_analyze(test_fn)
fn_def_node = node.body[0]
self.assertScopeIs(
anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',))
anno.getanno(fn_def_node, anno.Static.SCOPE), ('a',), ('f',))
scope = anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('x', 'y'), ('y',))
scope = anno.getanno(fn_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('x', 'y'), ('y',))
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
def test_constructor_attributes(self):
@ -482,64 +513,154 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
self.assertScopeIs(
anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ())
def test_params(self):
def test_fn(a, b): # pylint: disable=unused-argument
return b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b',), ())
self.assertScopeIs(body_scope.parent, ('b',), ())
args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
def test_lambda_captures_reads(self):
def test_lambda(self):
def test_fn(a, b):
return lambda: a + b
return lambda: (a + b)
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('a', 'b'), ())
# Nothing local to the lambda is tracked.
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
def test_lambda_params_are_isolated(self):
fn_node = node
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
lam_def_node = node.body[0].value
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
self.assertScopeIs(scope, (), ())
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
self.assertSymbolSetsAre((), scope.bound, 'BOUND')
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre((), scope.params.keys(), 'lambda params')
def test_lambda_params_args(self):
def test_fn(a, b): # pylint: disable=unused-argument
return lambda a: a + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b',), ())
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
# Note: `a` in `a + b` is not "read" here because it's hidden by the `a`
# argument.
self.assertScopeIs(scope, ('b',), ())
lam_def_node = node.body[0].value
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
self.assertScopeIs(scope, (), ())
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
self.assertSymbolSetsAre(('a',), scope.bound, 'BOUND')
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a',), scope.params.keys(), 'lambda params')
def test_lambda_params_arg_defaults(self):
def test_fn(a, b, c): # pylint: disable=unused-argument
return lambda b=c: a + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
# Note: `b` is not "read" here because it's hidden by the argument.
self.assertScopeIs(scope, ('a', 'c'), ())
lam_def_node = node.body[0].value
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
self.assertScopeIs(scope, ('c',), ())
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b'), ())
self.assertSymbolSetsAre(('b',), scope.bound, 'BOUND')
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('b',), scope.params.keys(), 'lambda params')
def test_lambda_complex(self):
def test_fn(a, b, c, d): # pylint: disable=unused-argument
a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
def test_fn(a, b, c, d, e): # pylint: disable=unused-argument
a = (lambda a, b, c=e: a + b + c)(d, 1, 2) + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('d', 'b', 'e'), ('a',))
lam_def_node = node.body[0].value.left.func
scope = anno.getanno(lam_def_node, anno.Static.SCOPE)
self.assertScopeIs(scope, ('e',), ())
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
self.assertSymbolSetsAre(('a', 'b', 'c'), scope.bound, 'BOUND')
scope = anno.getanno(lam_def_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(
('a', 'b', 'c'), scope.params.keys(), 'lambda params')
def test_lambda_nested(self):
def test_fn(a, b, c, d, e): # pylint: disable=unused-argument
a = lambda a, b: d(lambda b: a + b + c) # pylint: disable=undefined-variable
def test_fn(a, b, c, d, e, f): # pylint: disable=unused-argument
a = lambda a, b: d(lambda b=f: a + b + c) # pylint: disable=undefined-variable
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('c', 'd'), ('a',))
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('d', 'c', 'f'), ('a',))
outer_lam_def = node.body[0].value
scope = anno.getanno(outer_lam_def, anno.Static.SCOPE)
self.assertScopeIs(scope, (), ())
scope = anno.getanno(outer_lam_def, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('d', 'f', 'a', 'c'), ())
scope = anno.getanno(outer_lam_def, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('d', 'f', 'a', 'c'), ())
self.assertSymbolSetsAre(('a', 'b'), scope.bound, 'BOUND')
scope = anno.getanno(outer_lam_def.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a', 'b'), scope.params.keys(), 'lambda params')
inner_lam_def = outer_lam_def.body.args[0]
scope = anno.getanno(inner_lam_def, anno.Static.SCOPE)
self.assertScopeIs(scope, ('f',), ())
scope = anno.getanno(inner_lam_def, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
scope = anno.getanno(inner_lam_def, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'b', 'c'), ())
self.assertSymbolSetsAre(('b',), scope.bound, 'BOUND')
scope = anno.getanno(inner_lam_def.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('b',), scope.params.keys(), 'lambda params')
def test_comprehension_targets_are_isolated(self):

View File

@ -48,6 +48,9 @@ class NodeAnno(NoValue):
ARGS_SCOPE = 'The scope for the argument list of a function call.'
COND_SCOPE = 'The scope for the test node of a conditional statement.'
ITERATE_SCOPE = 'The scope for the iterate assignment of a for loop.'
ARGS_AND_BODY_SCOPE = (
'The scope for the main body of a function or lambda, including its'
' arguments.')
BODY_SCOPE = (
'The scope for the main body of a statement (True branch for if '
'statements, main body for loops).')