Internal cleanup: process global and nonlocal in a more consistent fashion during static analysis. Disconnect definitions of symbols between local functions and their parent.
PiperOrigin-RevId: 307630337 Change-Id: I3c5a79b35a52f83c6b32239c43869ec4ebd34380
This commit is contained in:
parent
6f9a5289f9
commit
e830abc58c
@ -43,10 +43,8 @@ import collections
|
|||||||
import weakref
|
import weakref
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
# pylint:disable=g-bad-import-order
|
|
||||||
|
|
||||||
import gast
|
import gast
|
||||||
# pylint:enable=g-bad-import-order
|
import six
|
||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
@ -207,6 +205,18 @@ class GraphVisitor(object):
|
|||||||
node: self.init_state(node) for node in self.graph.index.values()
|
node: self.init_state(node) for node in self.graph.index.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def can_ignore(self, node):
|
||||||
|
"""Returns True if the node can safely be assumed not to touch variables."""
|
||||||
|
ast_node = node.ast_node
|
||||||
|
if anno.hasanno(ast_node, anno.Basic.SKIP_PROCESSING):
|
||||||
|
return True
|
||||||
|
if six.PY2:
|
||||||
|
if (isinstance(ast_node, gast.Name) and
|
||||||
|
ast_node.id in ('None', 'True', 'False')):
|
||||||
|
return True
|
||||||
|
return isinstance(ast_node,
|
||||||
|
(gast.Break, gast.Continue, gast.Raise, gast.Pass))
|
||||||
|
|
||||||
def _visit_internal(self, mode):
|
def _visit_internal(self, mode):
|
||||||
"""Visits the CFG, depth-first."""
|
"""Visits the CFG, depth-first."""
|
||||||
assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
|
assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
|
||||||
|
@ -131,6 +131,7 @@ KNOWN_STRING_CONSTRUCTOR_ERRORS = (
|
|||||||
RuntimeError,
|
RuntimeError,
|
||||||
StopIteration,
|
StopIteration,
|
||||||
TypeError,
|
TypeError,
|
||||||
|
UnboundLocalError,
|
||||||
ValueError,
|
ValueError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -223,6 +223,7 @@ class ActivityAnalyzer(transformer.Base):
|
|||||||
|
|
||||||
def __init__(self, context, parent_scope=None):
|
def __init__(self, context, parent_scope=None):
|
||||||
super(ActivityAnalyzer, self).__init__(context)
|
super(ActivityAnalyzer, self).__init__(context)
|
||||||
|
self.allow_skips = False
|
||||||
self.scope = Scope(parent_scope, isolated=True)
|
self.scope = Scope(parent_scope, isolated=True)
|
||||||
|
|
||||||
# Note: all these flags crucially rely on the respective nodes are
|
# Note: all these flags crucially rely on the respective nodes are
|
||||||
@ -327,8 +328,21 @@ class ActivityAnalyzer(transformer.Base):
|
|||||||
return self._process_statement(node)
|
return self._process_statement(node)
|
||||||
|
|
||||||
def visit_Global(self, node):
|
def visit_Global(self, node):
|
||||||
|
self._enter_scope(False)
|
||||||
for name in node.names:
|
for name in node.names:
|
||||||
self.scope.globals.add(qual_names.QN(name))
|
qn = qual_names.QN(name)
|
||||||
|
self.scope.read.add(qn)
|
||||||
|
self.scope.globals.add(qn)
|
||||||
|
self._exit_and_record_scope(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Nonlocal(self, node):
|
||||||
|
self._enter_scope(False)
|
||||||
|
for name in node.names:
|
||||||
|
qn = qual_names.QN(name)
|
||||||
|
self.scope.read.add(qn)
|
||||||
|
self.scope.bound.add(qn)
|
||||||
|
self._exit_and_record_scope(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_Expr(self, node):
|
def visit_Expr(self, node):
|
||||||
|
@ -43,7 +43,10 @@ class ActivityAnalyzerTest(activity_test.ActivityAnalyzerTestBase):
|
|||||||
node, _ = self._parse_and_analyze(test_fn)
|
node, _ = self._parse_and_analyze(test_fn)
|
||||||
fn_node = node
|
fn_node = node
|
||||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||||
self.assertScopeIs(body_scope, ('nonlocal_b', 'c'), ('nonlocal_a',))
|
self.assertScopeIs(
|
||||||
|
body_scope, ('nonlocal_a', 'nonlocal_b', 'c'), ('nonlocal_a',))
|
||||||
|
nonlocal_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
|
||||||
|
self.assertScopeIs(nonlocal_a_scope, ('nonlocal_a',), ())
|
||||||
|
|
||||||
def test_annotated_assign(self):
|
def test_annotated_assign(self):
|
||||||
b = int
|
b = int
|
||||||
|
@ -607,9 +607,11 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
|||||||
node, _ = self._parse_and_analyze(test_fn)
|
node, _ = self._parse_and_analyze(test_fn)
|
||||||
fn_node = node
|
fn_node = node
|
||||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||||
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
|
self.assertScopeIs(body_scope, ('global_a', 'global_b', 'c'), ('global_a',))
|
||||||
self.assertSetEqual(body_scope.globals, set(
|
self.assertSetEqual(body_scope.globals, set(
|
||||||
(QN('global_a'), QN('global_b'))))
|
(QN('global_a'), QN('global_b'))))
|
||||||
|
global_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
|
||||||
|
self.assertScopeIs(global_a_scope, ('global_a',), ())
|
||||||
|
|
||||||
def test_class_definition_basic(self):
|
def test_class_definition_basic(self):
|
||||||
|
|
||||||
|
@ -67,12 +67,8 @@ class Analyzer(cfg.GraphVisitor):
|
|||||||
live_in = gen | (live_out - kill)
|
live_in = gen | (live_out - kill)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
assert self.can_ignore(node), (node.ast_node, node)
|
||||||
# symbols.
|
|
||||||
# This Name node below is a literal name, e.g. False
|
|
||||||
assert isinstance(node.ast_node,
|
|
||||||
(gast.Name, gast.Continue, gast.Break, gast.Pass,
|
|
||||||
gast.Global, gast.Nonlocal)), type(node.ast_node)
|
|
||||||
live_out = set()
|
live_out = set()
|
||||||
for n in node.next:
|
for n in node.next:
|
||||||
live_out |= self.in_[n]
|
live_out |= self.in_[n]
|
||||||
@ -105,6 +101,7 @@ class WholeTreeAnalyzer(transformer.Base):
|
|||||||
|
|
||||||
def __init__(self, source_info, graphs):
|
def __init__(self, source_info, graphs):
|
||||||
super(WholeTreeAnalyzer, self).__init__(source_info)
|
super(WholeTreeAnalyzer, self).__init__(source_info)
|
||||||
|
self.allow_skips = False
|
||||||
self.graphs = graphs
|
self.graphs = graphs
|
||||||
self.current_analyzer = None
|
self.current_analyzer = None
|
||||||
self.analyzers = {}
|
self.analyzers = {}
|
||||||
|
@ -34,9 +34,7 @@ import gast
|
|||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
from tensorflow.python.autograph.pyct import cfg
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
|
||||||
|
|
||||||
|
|
||||||
class Definition(object):
|
class Definition(object):
|
||||||
@ -137,8 +135,12 @@ class Analyzer(cfg.GraphVisitor):
|
|||||||
# their ids are used in equality checks.
|
# their ids are used in equality checks.
|
||||||
if node not in self.gen_map:
|
if node not in self.gen_map:
|
||||||
node_symbols = {}
|
node_symbols = {}
|
||||||
# Every modification receives a definition.
|
# Every binding operation (assign, nonlocal, global, etc.) counts as a
|
||||||
for s in node_scope.modified:
|
# definition, with the exception of del, which only deletes without
|
||||||
|
# creating a new variable.
|
||||||
|
newly_defined = ((node_scope.bound | node_scope.globals) -
|
||||||
|
node_scope.deleted)
|
||||||
|
for s in newly_defined:
|
||||||
def_ = self._definition_factory()
|
def_ = self._definition_factory()
|
||||||
node_symbols[s] = def_
|
node_symbols[s] = def_
|
||||||
# Every param receives a definition. Params are not necessarily
|
# Every param receives a definition. Params are not necessarily
|
||||||
@ -153,41 +155,16 @@ class Analyzer(cfg.GraphVisitor):
|
|||||||
kill = node_scope.modified | node_scope.deleted
|
kill = node_scope.modified | node_scope.deleted
|
||||||
defs_out = gen | (defs_in - kill)
|
defs_out = gen | (defs_in - kill)
|
||||||
|
|
||||||
elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
|
|
||||||
# Special case for global and nonlocal: they generate a definition,
|
|
||||||
# but are not tracked by activity analysis.
|
|
||||||
if node not in self.gen_map:
|
|
||||||
node_symbols = {}
|
|
||||||
kill = set()
|
|
||||||
for s in node.ast_node.names:
|
|
||||||
qn = qual_names.QN(s)
|
|
||||||
# TODO(mdan): If definitions exist, should we preserve those instead?
|
|
||||||
# Incoming definitions may be present when this is a local function.
|
|
||||||
# In that case, the definitions of the nonlocal symbol from the
|
|
||||||
# enclosing function are available here. See self.extra_in.
|
|
||||||
kill.add(qn)
|
|
||||||
def_ = self._definition_factory()
|
|
||||||
node_symbols[qn] = def_
|
|
||||||
self.gen_map[node] = _NodeState(node_symbols)
|
|
||||||
|
|
||||||
gen = self.gen_map[node]
|
gen = self.gen_map[node]
|
||||||
defs_out = gen | (defs_in - kill)
|
defs_out = gen | (defs_in - kill)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
assert self.can_ignore(node), (node.ast_node, node)
|
||||||
# symbols.
|
|
||||||
# This Name node below is a literal name, e.g. False
|
|
||||||
# This can also happen if activity.py forgot to annotate the node with a
|
|
||||||
# scope object.
|
|
||||||
assert isinstance(node.ast_node,
|
|
||||||
(gast.Name, gast.Break, gast.Continue, gast.Raise,
|
|
||||||
gast.Pass)), (node.ast_node, node)
|
|
||||||
defs_out = defs_in
|
defs_out = defs_in
|
||||||
|
|
||||||
self.in_[node] = defs_in
|
self.in_[node] = defs_in
|
||||||
self.out[node] = defs_out
|
self.out[node] = defs_out
|
||||||
|
|
||||||
# TODO(mdan): Move this to the superclass?
|
|
||||||
return prev_defs_out != defs_out
|
return prev_defs_out != defs_out
|
||||||
|
|
||||||
|
|
||||||
@ -205,6 +182,7 @@ class TreeAnnotator(transformer.Base):
|
|||||||
|
|
||||||
def __init__(self, source_info, graphs, definition_factory):
|
def __init__(self, source_info, graphs, definition_factory):
|
||||||
super(TreeAnnotator, self).__init__(source_info)
|
super(TreeAnnotator, self).__init__(source_info)
|
||||||
|
self.allow_skips = False
|
||||||
self.definition_factory = definition_factory
|
self.definition_factory = definition_factory
|
||||||
self.graphs = graphs
|
self.graphs = graphs
|
||||||
self.current_analyzer = None
|
self.current_analyzer = None
|
||||||
@ -214,28 +192,11 @@ class TreeAnnotator(transformer.Base):
|
|||||||
parent_analyzer = self.current_analyzer
|
parent_analyzer = self.current_analyzer
|
||||||
subgraph = self.graphs[node]
|
subgraph = self.graphs[node]
|
||||||
|
|
||||||
# Preorder tree processing:
|
|
||||||
# 1. if this is a child function, the parent was already analyzed and it
|
|
||||||
# has the proper state value for the subgraph's entry
|
|
||||||
# 2. analyze the current function body
|
|
||||||
# 2. recursively walk the subtree; child functions will be processed
|
|
||||||
analyzer = Analyzer(subgraph, self.definition_factory)
|
analyzer = Analyzer(subgraph, self.definition_factory)
|
||||||
if parent_analyzer is not None:
|
|
||||||
# Wire the state between the two subgraphs' analyzers.
|
|
||||||
parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
|
|
||||||
# Exception: symbols modified in the child function are local to it
|
|
||||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
|
||||||
parent_out_state -= body_scope.modified
|
|
||||||
analyzer.extra_in[node.args] = parent_out_state
|
|
||||||
|
|
||||||
# Complete the analysis for the local function and annotate its body.
|
|
||||||
analyzer.visit_forward()
|
analyzer.visit_forward()
|
||||||
|
|
||||||
# Recursively process any remaining subfunctions.
|
# Recursively process any remaining subfunctions.
|
||||||
self.current_analyzer = analyzer
|
self.current_analyzer = analyzer
|
||||||
# Note: not visiting name, decorator_list and returns because they don't
|
|
||||||
# apply to this analysis.
|
|
||||||
# TODO(mdan): Should we still process the function name?
|
|
||||||
node.args = self.visit(node.args)
|
node.args = self.visit(node.args)
|
||||||
node.body = self.visit_block(node.body)
|
node.body = self.visit_block(node.body)
|
||||||
self.current_analyzer = parent_analyzer
|
self.current_analyzer = parent_analyzer
|
||||||
|
@ -78,7 +78,7 @@ class ReachingDefinitionsAnalyzerTest(
|
|||||||
|
|
||||||
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
|
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
|
||||||
|
|
||||||
self.assertHasDefinedIn(local_body[1], ('a', 'b', 'local_fn'))
|
self.assertHasDefinedIn(local_body[1], ('a', 'b'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -254,7 +254,8 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
|
|||||||
self.assertHasDefs(fn_body[2].value, 2)
|
self.assertHasDefs(fn_body[2].value, 2)
|
||||||
|
|
||||||
inner_fn_body = fn_body[1].body[1].body
|
inner_fn_body = fn_body[1].body[1].body
|
||||||
self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if)
|
def_of_a_in_foo = inner_fn_body[0].value
|
||||||
|
self.assertHasDefs(def_of_a_in_foo, 0)
|
||||||
|
|
||||||
def test_nested_functions_isolation(self):
|
def test_nested_functions_isolation(self):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user