Internal cleanup: process global and nonlocal in a more consistent fashion during static analysis. Disconnect definitions of symbols between local functions and their parent.
PiperOrigin-RevId: 307630337 Change-Id: I3c5a79b35a52f83c6b32239c43869ec4ebd34380
This commit is contained in:
parent
6f9a5289f9
commit
e830abc58c
@ -43,10 +43,8 @@ import collections
|
||||
import weakref
|
||||
from enum import Enum
|
||||
|
||||
# pylint:disable=g-bad-import-order
|
||||
|
||||
import gast
|
||||
# pylint:enable=g-bad-import-order
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
@ -207,6 +205,18 @@ class GraphVisitor(object):
|
||||
node: self.init_state(node) for node in self.graph.index.values()
|
||||
}
|
||||
|
||||
def can_ignore(self, node):
|
||||
"""Returns True if the node can safely be assumed not to touch variables."""
|
||||
ast_node = node.ast_node
|
||||
if anno.hasanno(ast_node, anno.Basic.SKIP_PROCESSING):
|
||||
return True
|
||||
if six.PY2:
|
||||
if (isinstance(ast_node, gast.Name) and
|
||||
ast_node.id in ('None', 'True', 'False')):
|
||||
return True
|
||||
return isinstance(ast_node,
|
||||
(gast.Break, gast.Continue, gast.Raise, gast.Pass))
|
||||
|
||||
def _visit_internal(self, mode):
|
||||
"""Visits the CFG, depth-first."""
|
||||
assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
|
||||
|
@ -131,6 +131,7 @@ KNOWN_STRING_CONSTRUCTOR_ERRORS = (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
TypeError,
|
||||
UnboundLocalError,
|
||||
ValueError,
|
||||
)
|
||||
|
||||
|
@ -223,6 +223,7 @@ class ActivityAnalyzer(transformer.Base):
|
||||
|
||||
def __init__(self, context, parent_scope=None):
|
||||
super(ActivityAnalyzer, self).__init__(context)
|
||||
self.allow_skips = False
|
||||
self.scope = Scope(parent_scope, isolated=True)
|
||||
|
||||
# Note: all these flags crucially rely on the respective nodes are
|
||||
@ -327,8 +328,21 @@ class ActivityAnalyzer(transformer.Base):
|
||||
return self._process_statement(node)
|
||||
|
||||
def visit_Global(self, node):
|
||||
self._enter_scope(False)
|
||||
for name in node.names:
|
||||
self.scope.globals.add(qual_names.QN(name))
|
||||
qn = qual_names.QN(name)
|
||||
self.scope.read.add(qn)
|
||||
self.scope.globals.add(qn)
|
||||
self._exit_and_record_scope(node)
|
||||
return node
|
||||
|
||||
def visit_Nonlocal(self, node):
|
||||
self._enter_scope(False)
|
||||
for name in node.names:
|
||||
qn = qual_names.QN(name)
|
||||
self.scope.read.add(qn)
|
||||
self.scope.bound.add(qn)
|
||||
self._exit_and_record_scope(node)
|
||||
return node
|
||||
|
||||
def visit_Expr(self, node):
|
||||
|
@ -43,7 +43,10 @@ class ActivityAnalyzerTest(activity_test.ActivityAnalyzerTestBase):
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('nonlocal_b', 'c'), ('nonlocal_a',))
|
||||
self.assertScopeIs(
|
||||
body_scope, ('nonlocal_a', 'nonlocal_b', 'c'), ('nonlocal_a',))
|
||||
nonlocal_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
|
||||
self.assertScopeIs(nonlocal_a_scope, ('nonlocal_a',), ())
|
||||
|
||||
def test_annotated_assign(self):
|
||||
b = int
|
||||
|
@ -607,9 +607,11 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
|
||||
self.assertScopeIs(body_scope, ('global_a', 'global_b', 'c'), ('global_a',))
|
||||
self.assertSetEqual(body_scope.globals, set(
|
||||
(QN('global_a'), QN('global_b'))))
|
||||
global_a_scope = anno.getanno(fn_node.body[0], anno.Static.SCOPE)
|
||||
self.assertScopeIs(global_a_scope, ('global_a',), ())
|
||||
|
||||
def test_class_definition_basic(self):
|
||||
|
||||
|
@ -67,12 +67,8 @@ class Analyzer(cfg.GraphVisitor):
|
||||
live_in = gen | (live_out - kill)
|
||||
|
||||
else:
|
||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
||||
# symbols.
|
||||
# This Name node below is a literal name, e.g. False
|
||||
assert isinstance(node.ast_node,
|
||||
(gast.Name, gast.Continue, gast.Break, gast.Pass,
|
||||
gast.Global, gast.Nonlocal)), type(node.ast_node)
|
||||
assert self.can_ignore(node), (node.ast_node, node)
|
||||
|
||||
live_out = set()
|
||||
for n in node.next:
|
||||
live_out |= self.in_[n]
|
||||
@ -105,6 +101,7 @@ class WholeTreeAnalyzer(transformer.Base):
|
||||
|
||||
def __init__(self, source_info, graphs):
|
||||
super(WholeTreeAnalyzer, self).__init__(source_info)
|
||||
self.allow_skips = False
|
||||
self.graphs = graphs
|
||||
self.current_analyzer = None
|
||||
self.analyzers = {}
|
||||
|
@ -34,9 +34,7 @@ import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
|
||||
|
||||
class Definition(object):
|
||||
@ -137,8 +135,12 @@ class Analyzer(cfg.GraphVisitor):
|
||||
# their ids are used in equality checks.
|
||||
if node not in self.gen_map:
|
||||
node_symbols = {}
|
||||
# Every modification receives a definition.
|
||||
for s in node_scope.modified:
|
||||
# Every binding operation (assign, nonlocal, global, etc.) counts as a
|
||||
# definition, with the exception of del, which only deletes without
|
||||
# creating a new variable.
|
||||
newly_defined = ((node_scope.bound | node_scope.globals) -
|
||||
node_scope.deleted)
|
||||
for s in newly_defined:
|
||||
def_ = self._definition_factory()
|
||||
node_symbols[s] = def_
|
||||
# Every param receives a definition. Params are not necessarily
|
||||
@ -153,41 +155,16 @@ class Analyzer(cfg.GraphVisitor):
|
||||
kill = node_scope.modified | node_scope.deleted
|
||||
defs_out = gen | (defs_in - kill)
|
||||
|
||||
elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
|
||||
# Special case for global and nonlocal: they generate a definition,
|
||||
# but are not tracked by activity analysis.
|
||||
if node not in self.gen_map:
|
||||
node_symbols = {}
|
||||
kill = set()
|
||||
for s in node.ast_node.names:
|
||||
qn = qual_names.QN(s)
|
||||
# TODO(mdan): If definitions exist, should we preserve those instead?
|
||||
# Incoming definitions may be present when this is a local function.
|
||||
# In that case, the definitions of the nonlocal symbol from the
|
||||
# enclosing function are available here. See self.extra_in.
|
||||
kill.add(qn)
|
||||
def_ = self._definition_factory()
|
||||
node_symbols[qn] = def_
|
||||
self.gen_map[node] = _NodeState(node_symbols)
|
||||
|
||||
gen = self.gen_map[node]
|
||||
defs_out = gen | (defs_in - kill)
|
||||
|
||||
else:
|
||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
||||
# symbols.
|
||||
# This Name node below is a literal name, e.g. False
|
||||
# This can also happen if activity.py forgot to annotate the node with a
|
||||
# scope object.
|
||||
assert isinstance(node.ast_node,
|
||||
(gast.Name, gast.Break, gast.Continue, gast.Raise,
|
||||
gast.Pass)), (node.ast_node, node)
|
||||
assert self.can_ignore(node), (node.ast_node, node)
|
||||
defs_out = defs_in
|
||||
|
||||
self.in_[node] = defs_in
|
||||
self.out[node] = defs_out
|
||||
|
||||
# TODO(mdan): Move this to the superclass?
|
||||
return prev_defs_out != defs_out
|
||||
|
||||
|
||||
@ -205,6 +182,7 @@ class TreeAnnotator(transformer.Base):
|
||||
|
||||
def __init__(self, source_info, graphs, definition_factory):
|
||||
super(TreeAnnotator, self).__init__(source_info)
|
||||
self.allow_skips = False
|
||||
self.definition_factory = definition_factory
|
||||
self.graphs = graphs
|
||||
self.current_analyzer = None
|
||||
@ -214,28 +192,11 @@ class TreeAnnotator(transformer.Base):
|
||||
parent_analyzer = self.current_analyzer
|
||||
subgraph = self.graphs[node]
|
||||
|
||||
# Preorder tree processing:
|
||||
# 1. if this is a child function, the parent was already analyzed and it
|
||||
# has the proper state value for the subgraph's entry
|
||||
# 2. analyze the current function body
|
||||
# 2. recursively walk the subtree; child functions will be processed
|
||||
analyzer = Analyzer(subgraph, self.definition_factory)
|
||||
if parent_analyzer is not None:
|
||||
# Wire the state between the two subgraphs' analyzers.
|
||||
parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
|
||||
# Exception: symbols modified in the child function are local to it
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
parent_out_state -= body_scope.modified
|
||||
analyzer.extra_in[node.args] = parent_out_state
|
||||
|
||||
# Complete the analysis for the local function and annotate its body.
|
||||
analyzer.visit_forward()
|
||||
|
||||
# Recursively process any remaining subfunctions.
|
||||
self.current_analyzer = analyzer
|
||||
# Note: not visiting name, decorator_list and returns because they don't
|
||||
# apply to this analysis.
|
||||
# TODO(mdan): Should we still process the function name?
|
||||
node.args = self.visit(node.args)
|
||||
node.body = self.visit_block(node.body)
|
||||
self.current_analyzer = parent_analyzer
|
||||
|
@ -78,7 +78,7 @@ class ReachingDefinitionsAnalyzerTest(
|
||||
|
||||
self.assertSameDef(local_body[1].test, local_body[2].value.elts[0])
|
||||
|
||||
self.assertHasDefinedIn(local_body[1], ('a', 'b', 'local_fn'))
|
||||
self.assertHasDefinedIn(local_body[1], ('a', 'b'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -254,7 +254,8 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
|
||||
self.assertHasDefs(fn_body[2].value, 2)
|
||||
|
||||
inner_fn_body = fn_body[1].body[1].body
|
||||
self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if)
|
||||
def_of_a_in_foo = inner_fn_body[0].value
|
||||
self.assertHasDefs(def_of_a_in_foo, 0)
|
||||
|
||||
def test_nested_functions_isolation(self):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user