diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index 4f09c9db954..2a5ebb2d2b4 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -689,6 +689,7 @@ class AstToCfg(gast.NodeVisitor): def _process_exit_statement( self, node, exits_nodes_of_type, may_exit_via_except=False): + self.generic_visit(node) # Note: this is safe because we process functions separately. try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type) assert try_node is not None, '{} that is not enclosed by any of {}'.format( @@ -737,11 +738,9 @@ class AstToCfg(gast.NodeVisitor): # TODO(mdan): Track the CFG local to the class definition as well? self.builder = self.builder_stack.pop() - def visit_FunctionDef(self, node): - # We also keep the FunctionDef node in the CFG. This allows us to determine - # things like reaching definitions via closure. Note that the function body - # will be stored in a separate graph, because function definitions are not - # the same as function calls. + def _process_function_def(self, node, is_lambda): + # The function body is stored in a separate graph, because function + # definitions have effects very different from function calls. if self.builder is not None: self.builder.add_ordinary_node(node) @@ -752,8 +751,11 @@ class AstToCfg(gast.NodeVisitor): self.builder.enter_section(node) self._process_basic_statement(node.args) - for stmt in node.body: - self.visit(stmt) + if is_lambda: + self._process_exit_statement(node.body, (gast.Lambda,)) + else: + for stmt in node.body: + self.visit(stmt) self.builder.exit_section(node) self._exit_lexical_scope(node) @@ -761,6 +763,12 @@ class AstToCfg(gast.NodeVisitor): self.cfgs[node] = self.builder.build() self.builder = self.builder_stack.pop() + def visit_FunctionDef(self, node): + self._process_function_def(node, is_lambda=False) + + def visit_Lambda(self, node): + self._process_function_def(node, is_lambda=True) + def visit_Return(self, node): self._process_exit_statement(node, (gast.FunctionDef,)) diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py index d0b88c84a7f..7995555f6ef 100644 --- a/tensorflow/python/autograph/pyct/cfg_test.py +++ b/tensorflow/python/autograph/pyct/cfg_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.python.autograph.pyct import cfg from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test @@ -1030,16 +1032,57 @@ class AstToCfgTest(test.TestCase): a = lambda b: a + b return a - graph, = self._build_cfg(test_fn).values() + graphs = self._build_cfg(test_fn) + for k, v in graphs.items(): + if isinstance(k, gast.Lambda): + lam_graph = v + else: + fn_graph = v self.assertGraphMatches( - graph, + fn_graph, ( - ('a', 'a = (lambda b: (a + b))', 'return a'), + ('a', '(lambda b: (a + b))', 'a = (lambda b: (a + b))'), + ('(lambda b: (a + b))', 'a = (lambda b: (a + b))', 'return a'), ('a = (lambda b: (a + b))', 'return a', None), ), ) - self.assertGraphEnds(graph, 'a', ('return a',)) + self.assertGraphEnds(fn_graph, 'a', ('return a',)) + self.assertGraphMatches( + lam_graph, + ( + ('b', '(a + b)', None), + ), + ) + self.assertGraphEnds(lam_graph, 'b', ('(a + b)',)) + + def test_lambda_in_return(self): + + def test_fn(a): + return lambda b: a + b + + graphs = self._build_cfg(test_fn) + for k, v in graphs.items(): + if isinstance(k, gast.Lambda): + lam_graph = v + else: + fn_graph = v + + self.assertGraphMatches( + fn_graph, + ( + ('a', '(lambda b: (a + b))', 'return (lambda b: (a + b))'), + ('(lambda b: (a + b))', 'return (lambda b: (a + b))', None), + ), + ) + self.assertGraphEnds(fn_graph, 'a', ('return (lambda b: (a + b))',)) + self.assertGraphMatches( + lam_graph, + ( + ('b', '(a + b)', None), + ), + ) + self.assertGraphEnds(lam_graph, 'b', ('(a + b)',)) def test_pass(self):