Internal cleanup: process global and nonlocal in a more consistent fashion during static analysis. Disconnect definitions of symbols between local functions and their parent.

PiperOrigin-RevId: 307630337
Change-Id: I3c5a79b35a52f83c6b32239c43869ec4ebd34380
This commit is contained in:
Dan Moldovan 2020-04-21 10:21:24 -07:00 committed by TensorFlower Gardener
parent 6f9a5289f9
commit e830abc58c
9 changed files with 50 additions and 61 deletions

View File

@ -43,10 +43,8 @@ import collections
import weakref
from enum import Enum
# pylint:disable=g-bad-import-order
import gast
# pylint:enable=g-bad-import-order
import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
@ -207,6 +205,18 @@ class GraphVisitor(object):
node: self.init_state(node) for node in self.graph.index.values()
}
def can_ignore(self, node):
"""Returns True if the node can safely be assumed not to touch variables."""
ast_node = node.ast_node
if anno.hasanno(ast_node, anno.Basic.SKIP_PROCESSING):
return True
if six.PY2:
if (isinstance(ast_node, gast.Name) and
ast_node.id in ('None', 'True', 'False')):
return True
return isinstance(ast_node,
(gast.Break, gast.Continue, gast.Raise, gast.Pass))
def _visit_internal(self, mode):
"""Visits the CFG, depth-first."""
assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)

View File

@ -131,6 +131,7 @@ KNOWN_STRING_CONSTRUCTOR_ERRORS = (
RuntimeError,
StopIteration,
TypeError,
UnboundLocalError,
ValueError,
)

View File

@ -223,6 +223,7 @@ class ActivityAnalyzer(transformer.Base):
def __init__(self, context, parent_scope=None):
super(ActivityAnalyzer, self).__init__(context)
self.allow_skips = False
self.scope = Scope(parent_scope, isolated=True)
# Note: all these flags crucially rely on the respective nodes are
@ -327,8 +328,21 @@ class ActivityAnalyzer(transformer.Base):
return self._process_statement(node)
def visit_Global(self, node):
self._enter_scope(False)
for name in node.names:
self.scope.globals.add(qual_names.QN(name))
qn = qual_names.QN(name)
self.scope.read.add(qn)
self.scope.globals.add(qn)
self._exit_and_record_scope(node)
return node
def visit_Nonlocal(self, node):
self._enter_scope(False)
for name in node.names:
qn = qual_names.QN(name)
self.scope.read.add(qn)
self.scope.bound.add(qn)
self._exit_and_record_scope(node)
return node
def visit_Expr(self, node):

View File

@ -43,7 +43,10 @@ class ActivityAnalyzerTest(activity_test.ActivityAnalyzerTestBase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('nonlocal_b', 'c'), ('nonlocal_a',))
self.assertScopeIs(
body_scope, ('nonlocal_a', 'nonlocal_b', 'c'), ('nonlocal_a',))
nonlocal_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
self.assertScopeIs(nonlocal_a_scope, ('nonlocal_a',), ())
def test_annotated_assign(self):
b = int

View File

@ -607,9 +607,11 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
self.assertScopeIs(body_scope, ('global_a', 'global_b', 'c'), ('global_a',))
self.assertSetEqual(body_scope.globals, set(
(QN('global_a'), QN('global_b'))))
global_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
self.assertScopeIs(global_a_scope, ('global_a',), ())
def test_class_definition_basic(self):

View File

@ -67,12 +67,8 @@ class Analyzer(cfg.GraphVisitor):
live_in = gen | (live_out - kill)
else:
# Nodes that don't have a scope annotation are assumed not to touch any
# symbols.
# This Name node below is a literal name, e.g. False
assert isinstance(node.ast_node,
(gast.Name, gast.Continue, gast.Break, gast.Pass,
gast.Global, gast.Nonlocal)), type(node.ast_node)
assert self.can_ignore(node), (node.ast_node, node)
live_out = set()
for n in node.next:
live_out |= self.in_[n]
@ -105,6 +101,7 @@ class WholeTreeAnalyzer(transformer.Base):
def __init__(self, source_info, graphs):
super(WholeTreeAnalyzer, self).__init__(source_info)
self.allow_skips = False
self.graphs = graphs
self.current_analyzer = None
self.analyzers = {}

View File

@ -34,9 +34,7 @@ import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import annos
class Definition(object):
@ -137,8 +135,12 @@ class Analyzer(cfg.GraphVisitor):
# their ids are used in equality checks.
if node not in self.gen_map:
node_symbols = {}
# Every modification receives a definition.
for s in node_scope.modified:
# Every binding operation (assign, nonlocal, global, etc.) counts as a
# definition, with the exception of del, which only deletes without
# creating a new variable.
newly_defined = ((node_scope.bound | node_scope.globals) -
node_scope.deleted)
for s in newly_defined:
def_ = self._definition_factory()
node_symbols[s] = def_
# Every param receives a definition. Params are not necessarily
@ -153,41 +155,16 @@ class Analyzer(cfg.GraphVisitor):
kill = node_scope.modified | node_scope.deleted
defs_out = gen | (defs_in - kill)
elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
# Special case for global and nonlocal: they generate a definition,
# but are not tracked by activity analysis.
if node not in self.gen_map:
node_symbols = {}
kill = set()
for s in node.ast_node.names:
qn = qual_names.QN(s)
# TODO(mdan): If definitions exist, should we preserve those instead?
# Incoming definitions may be present when this is a local function.
# In that case, the definitions of the nonlocal symbol from the
# enclosing function are available here. See self.extra_in.
kill.add(qn)
def_ = self._definition_factory()
node_symbols[qn] = def_
self.gen_map[node] = _NodeState(node_symbols)
gen = self.gen_map[node]
defs_out = gen | (defs_in - kill)
else:
# Nodes that don't have a scope annotation are assumed not to touch any
# symbols.
# This Name node below is a literal name, e.g. False
# This can also happen if activity.py forgot to annotate the node with a
# scope object.
assert isinstance(node.ast_node,
(gast.Name, gast.Break, gast.Continue, gast.Raise,
gast.Pass)), (node.ast_node, node)
assert self.can_ignore(node), (node.ast_node, node)
defs_out = defs_in
self.in_[node] = defs_in
self.out[node] = defs_out
# TODO(mdan): Move this to the superclass?
return prev_defs_out != defs_out
@ -205,6 +182,7 @@ class TreeAnnotator(transformer.Base):
def __init__(self, source_info, graphs, definition_factory):
super(TreeAnnotator, self).__init__(source_info)
self.allow_skips = False
self.definition_factory = definition_factory
self.graphs = graphs
self.current_analyzer = None
@ -214,28 +192,11 @@ class TreeAnnotator(transformer.Base):
parent_analyzer = self.current_analyzer
subgraph = self.graphs[node]
# Preorder tree processing:
# 1. if this is a child function, the parent was already analyzed and it
# has the proper state value for the subgraph's entry
# 2. analyze the current function body
# 2. recursively walk the subtree; child functions will be processed
analyzer = Analyzer(subgraph, self.definition_factory)
if parent_analyzer is not None:
# Wire the state between the two subgraphs' analyzers.
parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
# Exception: symbols modified in the child function are local to it
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
parent_out_state -= body_scope.modified
analyzer.extra_in[node.args] = parent_out_state
# Complete the analysis for the local function and annotate its body.
analyzer.visit_forward()
# Recursively process any remaining subfunctions.
self.current_analyzer = analyzer
# Note: not visiting name, decorator_list and returns because they don't
# apply to this analysis.
# TODO(mdan): Should we still process the function name?
node.args = self.visit(node.args)
node.body = self.visit_block(node.body)
self.current_analyzer = parent_analyzer

View File

@ -78,7 +78,7 @@ class ReachingDefinitionsAnalyzerTest(
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
self.assertHasDefinedIn(local_body[1], ('a', 'b', 'local_fn'))
self.assertHasDefinedIn(local_body[1], ('a', 'b'))
if __name__ == '__main__':

View File

@ -254,7 +254,8 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
self.assertHasDefs(fn_body[2].value, 2)
inner_fn_body = fn_body[1].body[1].body
self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if)
def_of_a_in_foo = inner_fn_body[0].value
self.assertHasDefs(def_of_a_in_foo, 0)
def test_nested_functions_isolation(self):