From 3f8dcd3e288f213001eace4aea0f22cfb1b65946 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Feb 2019 12:05:05 -0800 Subject: [PATCH] Allow exceptions in code that will be staged. Consider all exceptions to be exiting the CFG, with no explicit support for exception-based control-flow. It may incidentally work to use exception-based control flow in code that is never staged. PiperOrigin-RevId: 232719435 --- .../converters/side_effect_guards.py | 4 +++ tensorflow/python/autograph/pyct/cfg.py | 36 ++++++++----------- .../pyct/static_analysis/liveness.py | 4 +++ .../static_analysis/reaching_definitions.py | 7 +++- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py index d7c0951fcc6..7e556d95139 100644 --- a/tensorflow/python/autograph/converters/side_effect_guards.py +++ b/tensorflow/python/autograph/converters/side_effect_guards.py @@ -125,6 +125,10 @@ class SideEffectGuardTransformer(converter.Base): node.orelse = self._visit_and_reindent(node.orelse) return node + # TODO(b/123995141) Remove once ExceptionHandlers are in the CFG + def visit_ExceptHandler(self, node): + return node + def visit_Expr(self, node): self.generic_visit(node) if isinstance(node.value, gast.Call): diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index fdfcd4dcc15..0cedfa84ab3 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -393,6 +393,8 @@ class GraphBuilder(object): def _connect_jump_to_finally_sections(self, node): """Connects a jump node to the finally sections protecting it.""" cursor = set((node,)) + if node not in self.finally_sections: + return cursor for guard_section_id in self.finally_sections[node]: guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] self._connect_nodes(cursor, guard_begin) @@ -620,10 +622,10 @@ class AstToCfg(gast.NodeVisitor): leaving_node = self.lexical_scopes.pop() assert node == leaving_node - def _get_enclosing_scopes(self, include, stop_at): + def _get_enclosing_finally_scopes(self, stop_at): included = [] for node in reversed(self.lexical_scopes): - if isinstance(node, include): + if isinstance(node, gast.Try) and node.finalbody: included.append(node) if isinstance(node, stop_at): return node, included @@ -635,10 +637,8 @@ class AstToCfg(gast.NodeVisitor): def _process_exit_statement(self, node, *exits_nodes_of_type): # Note: this is safe because we process functions separately. - try_node, guards = self._get_enclosing_scopes( - include=(gast.Try,), - stop_at=tuple(exits_nodes_of_type), - ) + try_node, guards = self._get_enclosing_finally_scopes( + tuple(exits_nodes_of_type)) if try_node is None: raise ValueError( '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) @@ -646,10 +646,8 @@ class AstToCfg(gast.NodeVisitor): def _process_continue_statement(self, node, *loops_to_nodes_of_type): # Note: this is safe because we process functions separately. - try_node, guards = self._get_enclosing_scopes( - include=(gast.Try,), - stop_at=tuple(loops_to_nodes_of_type), - ) + try_node, guards = self._get_enclosing_finally_scopes( + tuple(loops_to_nodes_of_type)) if try_node is None: raise ValueError('%s that is not enclosed by any of %s' % (node, loops_to_nodes_of_type)) @@ -698,10 +696,7 @@ class AstToCfg(gast.NodeVisitor): self._process_basic_statement(node) def visit_Raise(self, node): - try_node, guards = self._get_enclosing_scopes( - include=(gast.Try,), - stop_at=(gast.FunctionDef,), - ) + try_node, guards = self._get_enclosing_finally_scopes((gast.FunctionDef,)) if try_node is None: raise ValueError('%s that is not enclosed by any FunctionDef' % node) self.builder.add_error_node(node, guards) @@ -797,16 +792,13 @@ class AstToCfg(gast.NodeVisitor): for stmt in node.orelse: self.visit(stmt) - if node.handlers: - # TODO(mdan): Should we still support bare try/except? Might be confusing. - raise NotImplementedError('exceptions are not yet supported') - self._exit_lexical_scope(node) - self.builder.enter_finally_section(node) - for stmt in node.finalbody: - self.visit(stmt) - self.builder.exit_finally_section(node) + if node.finalbody: + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) def visit_With(self, node): # TODO(mdan): Mark the context manager's exit call as exit guard. diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py index f8b8d7fa77c..691b786db0d 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py @@ -219,6 +219,10 @@ class Annotator(transformer.Base): frozenset(self.current_analyzer.out[cfg_node])) return node + def visit_ExceptHandler(self, node): + # TODO(b/123995141) Add Exception Handlers to the CFG + return node + def resolve(node, source_info, graphs): """Resolves the live symbols at the exit of control flow statements. diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py index d1587d81780..6f0f09ee881 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py @@ -223,6 +223,10 @@ class TreeAnnotator(transformer.Base): def visit_global(self, node): raise NotImplementedError() + def visit_ExceptHandler(self, node): + # TODO(b/123995141) Add Exception Handlers to the CFG + return node + def visit_Name(self, node): if self.current_analyzer is None: # Names may appear outside function defs - for example in class @@ -232,7 +236,8 @@ class TreeAnnotator(transformer.Base): analyzer = self.current_analyzer cfg_node = self.current_cfg_node - assert cfg_node is not None, 'name node outside of any statement?' + assert cfg_node is not None, ('name node, %s, outside of any statement?' + % node.id) qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Load):