Track lambda functions separately in the CFG, consistent with normal function definitions.

PiperOrigin-RevId: 307828205
Change-Id: I1de54beba38db0213d0e921698d236aed1dcc355
This commit is contained in:
Dan Moldovan 2020-04-22 09:04:27 -07:00 committed by TensorFlower Gardener
parent ef4e3be946
commit 126f665f39
2 changed files with 62 additions and 11 deletions

View File

@ -689,6 +689,7 @@ class AstToCfg(gast.NodeVisitor):
def _process_exit_statement(
self, node, exits_nodes_of_type, may_exit_via_except=False):
self.generic_visit(node)
# Note: this is safe because we process functions separately.
try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type)
assert try_node is not None, '{} that is not enclosed by any of {}'.format(
@ -737,11 +738,9 @@ class AstToCfg(gast.NodeVisitor):
# TODO(mdan): Track the CFG local to the class definition as well?
self.builder = self.builder_stack.pop()
def visit_FunctionDef(self, node):
# We also keep the FunctionDef node in the CFG. This allows us to determine
# things like reaching definitions via closure. Note that the function body
# will be stored in a separate graph, because function definitions are not
# the same as function calls.
def _process_function_def(self, node, is_lambda):
# The function body is stored in a separate graph, because function
# definitions have effects very different from function calls.
if self.builder is not None:
self.builder.add_ordinary_node(node)
@ -752,8 +751,11 @@ class AstToCfg(gast.NodeVisitor):
self.builder.enter_section(node)
self._process_basic_statement(node.args)
for stmt in node.body:
self.visit(stmt)
if is_lambda:
self._process_exit_statement(node.body, (gast.Lambda,))
else:
for stmt in node.body:
self.visit(stmt)
self.builder.exit_section(node)
self._exit_lexical_scope(node)
@ -761,6 +763,12 @@ class AstToCfg(gast.NodeVisitor):
self.cfgs[node] = self.builder.build()
self.builder = self.builder_stack.pop()
def visit_FunctionDef(self, node):
self._process_function_def(node, is_lambda=False)
def visit_Lambda(self, node):
self._process_function_def(node, is_lambda=True)
def visit_Return(self, node):
self._process_exit_statement(node, (gast.FunctionDef,))

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
@ -1030,16 +1032,57 @@ class AstToCfgTest(test.TestCase):
a = lambda b: a + b
return a
graph, = self._build_cfg(test_fn).values()
graphs = self._build_cfg(test_fn)
for k, v in graphs.items():
if isinstance(k, gast.Lambda):
lam_graph = v
else:
fn_graph = v
self.assertGraphMatches(
graph,
fn_graph,
(
('a', 'a = (lambda b: (a + b))', 'return a'),
('a', '(lambda b: (a + b))', 'a = (lambda b: (a + b))'),
('(lambda b: (a + b))', 'a = (lambda b: (a + b))', 'return a'),
('a = (lambda b: (a + b))', 'return a', None),
),
)
self.assertGraphEnds(graph, 'a', ('return a',))
self.assertGraphEnds(fn_graph, 'a', ('return a',))
self.assertGraphMatches(
lam_graph,
(
('b', '(a + b)', None),
),
)
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
def test_lambda_in_return(self):
def test_fn(a):
return lambda b: a + b
graphs = self._build_cfg(test_fn)
for k, v in graphs.items():
if isinstance(k, gast.Lambda):
lam_graph = v
else:
fn_graph = v
self.assertGraphMatches(
fn_graph,
(
('a', '(lambda b: (a + b))', 'return (lambda b: (a + b))'),
('(lambda b: (a + b))', 'return (lambda b: (a + b))', None),
),
)
self.assertGraphEnds(fn_graph, 'a', ('return (lambda b: (a + b))',))
self.assertGraphMatches(
lam_graph,
(
('b', '(a + b)', None),
),
)
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
def test_pass(self):