Track lambda functions separately in the CFG, consistent with normal function definitions.
PiperOrigin-RevId: 307828205 Change-Id: I1de54beba38db0213d0e921698d236aed1dcc355
This commit is contained in:
parent
ef4e3be946
commit
126f665f39
@ -689,6 +689,7 @@ class AstToCfg(gast.NodeVisitor):
|
||||
|
||||
def _process_exit_statement(
|
||||
self, node, exits_nodes_of_type, may_exit_via_except=False):
|
||||
self.generic_visit(node)
|
||||
# Note: this is safe because we process functions separately.
|
||||
try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type)
|
||||
assert try_node is not None, '{} that is not enclosed by any of {}'.format(
|
||||
@ -737,11 +738,9 @@ class AstToCfg(gast.NodeVisitor):
|
||||
# TODO(mdan): Track the CFG local to the class definition as well?
|
||||
self.builder = self.builder_stack.pop()
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# We also keep the FunctionDef node in the CFG. This allows us to determine
|
||||
# things like reaching definitions via closure. Note that the function body
|
||||
# will be stored in a separate graph, because function definitions are not
|
||||
# the same as function calls.
|
||||
def _process_function_def(self, node, is_lambda):
|
||||
# The function body is stored in a separate graph, because function
|
||||
# definitions have effects very different from function calls.
|
||||
if self.builder is not None:
|
||||
self.builder.add_ordinary_node(node)
|
||||
|
||||
@ -752,8 +751,11 @@ class AstToCfg(gast.NodeVisitor):
|
||||
self.builder.enter_section(node)
|
||||
|
||||
self._process_basic_statement(node.args)
|
||||
for stmt in node.body:
|
||||
self.visit(stmt)
|
||||
if is_lambda:
|
||||
self._process_exit_statement(node.body, (gast.Lambda,))
|
||||
else:
|
||||
for stmt in node.body:
|
||||
self.visit(stmt)
|
||||
|
||||
self.builder.exit_section(node)
|
||||
self._exit_lexical_scope(node)
|
||||
@ -761,6 +763,12 @@ class AstToCfg(gast.NodeVisitor):
|
||||
self.cfgs[node] = self.builder.build()
|
||||
self.builder = self.builder_stack.pop()
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self._process_function_def(node, is_lambda=False)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
self._process_function_def(node, is_lambda=True)
|
||||
|
||||
def visit_Return(self, node):
|
||||
self._process_exit_statement(node, (gast.FunctionDef,))
|
||||
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.platform import test
|
||||
@ -1030,16 +1032,57 @@ class AstToCfgTest(test.TestCase):
|
||||
a = lambda b: a + b
|
||||
return a
|
||||
|
||||
graph, = self._build_cfg(test_fn).values()
|
||||
graphs = self._build_cfg(test_fn)
|
||||
for k, v in graphs.items():
|
||||
if isinstance(k, gast.Lambda):
|
||||
lam_graph = v
|
||||
else:
|
||||
fn_graph = v
|
||||
|
||||
self.assertGraphMatches(
|
||||
graph,
|
||||
fn_graph,
|
||||
(
|
||||
('a', 'a = (lambda b: (a + b))', 'return a'),
|
||||
('a', '(lambda b: (a + b))', 'a = (lambda b: (a + b))'),
|
||||
('(lambda b: (a + b))', 'a = (lambda b: (a + b))', 'return a'),
|
||||
('a = (lambda b: (a + b))', 'return a', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(graph, 'a', ('return a',))
|
||||
self.assertGraphEnds(fn_graph, 'a', ('return a',))
|
||||
self.assertGraphMatches(
|
||||
lam_graph,
|
||||
(
|
||||
('b', '(a + b)', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||
|
||||
def test_lambda_in_return(self):
|
||||
|
||||
def test_fn(a):
|
||||
return lambda b: a + b
|
||||
|
||||
graphs = self._build_cfg(test_fn)
|
||||
for k, v in graphs.items():
|
||||
if isinstance(k, gast.Lambda):
|
||||
lam_graph = v
|
||||
else:
|
||||
fn_graph = v
|
||||
|
||||
self.assertGraphMatches(
|
||||
fn_graph,
|
||||
(
|
||||
('a', '(lambda b: (a + b))', 'return (lambda b: (a + b))'),
|
||||
('(lambda b: (a + b))', 'return (lambda b: (a + b))', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(fn_graph, 'a', ('return (lambda b: (a + b))',))
|
||||
self.assertGraphMatches(
|
||||
lam_graph,
|
||||
(
|
||||
('b', '(a + b)', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||
|
||||
def test_pass(self):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user