Improve the completeness of the CFG by drawing edges from raise statements to all enclosing except blocks.

PiperOrigin-RevId: 295837876
Change-Id: I34e6ad8eb50e984fd526948d66ceaaf27c3b453a
This commit is contained in:
Dan Moldovan 2020-02-18 15:33:42 -08:00 committed by TensorFlower Gardener
parent f8822b0a55
commit 9189ce99fc
2 changed files with 167 additions and 21 deletions

View File

@ -21,13 +21,14 @@ a corresponding CFG counterpart.
Once built, the CFG itself is immutable, but the values it holds need not be;
they are usually annotated with information extracted by walking the graph.
Note: the CFG tries to include all code paths that MAY be taken, with the
follwing exceptions:
Tip: Use `Graph.as_dot` to visualize the CFG using any DOT viewer.
Note: the CFG tries to include all code paths that MAY be taken, with a single
notable exception:
* function calls do not generate edges corresponding to exceptions they may
raise (i.e. a function call in the middle of a block does not exit or jump
to an except block)
* raise never generates an edge to an except block
(TODO:mdan): Remove this last bullet.
raise (i.e. a function call in the middle of a block does not return or jump
to any except or finally block)
TODO(mdan): Consider adding the edges above. They'd only add ~O(n) edges.
"""
# TODO(mdan): The notion of 'statements' below is inaccurate.
@ -309,6 +310,9 @@ class GraphBuilder(object):
# Continue jumps keyed by the section they affect.
self.continues = {}
# Raise jumps keyed by the except section guarding them.
self.raises = {}
# The entry of conditional sections, keyed by the section.
self.cond_entry = {}
# Lists of leaf nodes corresponding to each branch in the section.
@ -429,9 +433,12 @@ class GraphBuilder(object):
section_id: Hashable, the node for which ast_node should be considered
to be an exit node
guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
Returns:
Node
"""
node = self._add_jump_node(ast_node, guards)
self.exits[section_id].add(node)
return node
def add_continue_node(self, ast_node, section_id, guards):
"""Grows the graph by adding a reentry node.
@ -447,6 +454,21 @@ class GraphBuilder(object):
node = self._add_jump_node(ast_node, guards)
self.continues[section_id].add(node)
def connect_raise_node(self, node, except_guards):
"""Adds extra connection between a raise node and containing except guards.
The node is a graph node, not an ast node.
Args:
node: Node
except_guards: Tuple[ast.AST, ...], the except sections that guard node
"""
for guard in except_guards:
if guard in self.raises:
self.raises[guard].append(node)
else:
self.raises[guard] = [node]
def enter_section(self, section_id):
"""Enters a regular section.
@ -537,6 +559,11 @@ class GraphBuilder(object):
del self.cond_entry[section_id]
del self.cond_leaves[section_id]
def enter_except_section(self, section_id):
"""Enters an except section."""
if section_id in self.raises:
self.leaves.update(self.raises[section_id])
def enter_finally_section(self, section_id):
"""Enters a finally section."""
# TODO(mdan): This, not the caller, should track the active sections.
@ -636,18 +663,31 @@ class AstToCfg(gast.NodeVisitor):
return node, included
return None, included
def _get_enclosing_except_scopes(self, stop_at):
included = []
for node in reversed(self.lexical_scopes):
if isinstance(node, gast.Try) and node.handlers:
included.extend(node.handlers)
if isinstance(node, stop_at):
break
return included
def _process_basic_statement(self, node):
self.generic_visit(node)
self.builder.add_ordinary_node(node)
def _process_exit_statement(self, node, *exits_nodes_of_type):
def _process_exit_statement(
self, node, exits_nodes_of_type, may_exit_via_except=False):
# Note: this is safe because we process functions separately.
try_node, guards = self._get_enclosing_finally_scopes(
tuple(exits_nodes_of_type))
if try_node is None:
raise ValueError(
'%s that is not enclosed by any of %s' % (node, exits_nodes_of_type))
self.builder.add_exit_node(node, try_node, guards)
try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type)
assert try_node is not None, '{} that is not enclosed by any of {}'.format(
node, exits_nodes_of_type)
node = self.builder.add_exit_node(node, try_node, guards)
if may_exit_via_except:
except_guards = self._get_enclosing_except_scopes(exits_nodes_of_type)
self.builder.connect_raise_node(node, except_guards)
def _process_continue_statement(self, node, *loops_to_nodes_of_type):
# Note: this is safe because we process functions separately.
@ -711,7 +751,7 @@ class AstToCfg(gast.NodeVisitor):
self.builder = self.builder_stack.pop()
def visit_Return(self, node):
self._process_exit_statement(node, gast.FunctionDef)
self._process_exit_statement(node, (gast.FunctionDef,))
def visit_Expr(self, node):
self._process_basic_statement(node)
@ -738,7 +778,8 @@ class AstToCfg(gast.NodeVisitor):
self._process_basic_statement(node)
def visit_Raise(self, node):
self._process_exit_statement(node, gast.FunctionDef)
self._process_exit_statement(
node, (gast.FunctionDef,), may_exit_via_except=True)
self.builder.errors.add(node)
def visit_Assert(self, node):
@ -818,13 +859,14 @@ class AstToCfg(gast.NodeVisitor):
self.builder.end_statement(node)
def visit_Break(self, node):
self._process_exit_statement(node, gast.While, gast.For)
self._process_exit_statement(node, (gast.While, gast.For,))
def visit_Continue(self, node):
self._process_continue_statement(node, gast.While, gast.For)
self._process_continue_statement(node, (gast.While, gast.For,))
def visit_ExceptHandler(self, node):
self.builder.begin_statement(node)
self.builder.enter_except_section(node)
if node.type is not None:
self.visit(node.type)

View File

@ -1309,21 +1309,125 @@ class AstToCfgTest(test.TestCase):
graph,
(
('a, b', '(a > 0)', ('raise b', 'return 0')),
('(a > 0)', 'raise b', None),
('(a > 0)', 'raise b', 'return 1'),
('(a > 0)', 'return 0', None),
(None, 'return 1', None),
('raise b', 'return 1', None),
),
)
self.assertStatementEdges(
graph,
(
('a, b', 'Try:2', None),
('a, b', 'If:3', None),
(None, 'ExceptHandler:7', None),
('a, b', 'If:3', 'return 1'),
('raise b', 'ExceptHandler:7', None),
),
)
self.assertGraphEnds(graph, 'a, b', ('return 0', 'return 1', 'raise b'))
def test_raise_exits(self):
def test_fn(a, b):
raise b
return a # pylint:disable=unreachable
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a, b', 'raise b', None),
(None, 'return a', None),
),
)
self.assertGraphEnds(graph, 'a, b', ('raise b', 'return a'))
def test_raise_triggers_enclosing_finally(self):
def test_fn(a):
try:
try:
raise a
return 1 # pylint:disable=unreachable
finally:
b = 1
return 2
finally:
b = 2
return b
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a', 'raise a', 'b = 1'),
(('raise a', 'return 1'), 'b = 1', 'b = 2'),
(None, 'return 1', 'b = 1'),
(None, 'return 2', 'b = 2'),
(('return 2', 'b = 1'), 'b = 2', None),
(None, 'return b', None),
),
)
self.assertGraphEnds(
graph, 'a', ('return b', 'b = 2'))
def test_raise_adds_finally_sortcuts(self):
def test_fn(a):
try:
try:
if a > 0:
raise a
c = 1
finally:
b = 1
c = 2
finally:
b = 2
return b, c
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a', '(a > 0)', ('raise a', 'c = 1')),
('(a > 0)', 'raise a', 'b = 1'),
('(a > 0)', 'c = 1', 'b = 1'),
(('raise a', 'c = 1'), 'b = 1', ('c = 2', 'b = 2')),
('b = 1', 'c = 2', 'b = 2'),
(('b = 1', 'c = 2'), 'b = 2', 'return (b, c)'),
('b = 2', 'return (b, c)', None),
),
)
self.assertGraphEnds(
graph, 'a', ('return (b, c)', 'b = 2'))
def test_raise_exits_via_except(self):
def test_fn(a, b):
try:
raise b
except a:
c = 1
except b:
c = 2
finally:
c += 3
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a, b', 'raise b', ('c = 1', 'c = 2', 'c += 3')),
('raise b', 'c = 1', 'c += 3'),
('raise b', 'c = 2', 'c += 3'),
(('raise b', 'c = 1', 'c = 2'), 'c += 3', None),
),
)
self.assertGraphEnds(graph, 'a, b', ('c += 3',))
def test_list_comprehension(self):
def test_fn(a):