Fix CFG to correctly account for lambda functions that appear in while and for loops.
PiperOrigin-RevId: 307956866 Change-Id: Ifed545cf034c306d6655af0a3ec48eaa0c9b68d3
This commit is contained in:
parent
7bd0df753f
commit
8d72dbe4ea
tensorflow/python/autograph/pyct
@ -842,6 +842,7 @@ class AstToCfg(gast.NodeVisitor):
|
||||
|
||||
self.builder.enter_section(node)
|
||||
|
||||
self.generic_visit(node.test)
|
||||
self.builder.enter_loop_section(node, node.test)
|
||||
for stmt in node.body:
|
||||
self.visit(stmt)
|
||||
@ -867,6 +868,7 @@ class AstToCfg(gast.NodeVisitor):
|
||||
# Note: Strictly speaking, this should be node.target + node.iter.
|
||||
# However, the activity analysis accounts for this inconsistency,
|
||||
# so dataflow analysis produces the correct values.
|
||||
self.generic_visit(node.iter)
|
||||
self.builder.enter_loop_section(node, node.iter)
|
||||
# Also include the "extra loop test" annotation, to capture things like the
|
||||
# control variable for return and break in for loops.
|
||||
|
@ -1084,6 +1084,66 @@ class AstToCfgTest(test.TestCase):
|
||||
)
|
||||
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||
|
||||
def test_lambda_in_while_loop_test(self):
|
||||
|
||||
def test_fn(a):
|
||||
while (lambda b: a + b)(a):
|
||||
pass
|
||||
|
||||
graphs = self._build_cfg(test_fn)
|
||||
for k, v in graphs.items():
|
||||
if isinstance(k, gast.Lambda):
|
||||
lam_graph = v
|
||||
else:
|
||||
fn_graph = v
|
||||
|
||||
self.assertGraphMatches(
|
||||
fn_graph,
|
||||
(
|
||||
('a', '(lambda b: (a + b))', '(lambda b: (a + b))(a)'),
|
||||
(('(lambda b: (a + b))', 'pass'), '(lambda b: (a + b))(a)', 'pass'),
|
||||
('(lambda b: (a + b))(a)', 'pass', '(lambda b: (a + b))(a)'),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(fn_graph, 'a', ('(lambda b: (a + b))(a)',))
|
||||
self.assertGraphMatches(
|
||||
lam_graph,
|
||||
(
|
||||
('b', '(a + b)', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||
|
||||
def test_lambda_in_for_loop_test(self):
|
||||
|
||||
def test_fn(a):
|
||||
for _ in (lambda b: a + b)(a):
|
||||
pass
|
||||
|
||||
graphs = self._build_cfg(test_fn)
|
||||
for k, v in graphs.items():
|
||||
if isinstance(k, gast.Lambda):
|
||||
lam_graph = v
|
||||
else:
|
||||
fn_graph = v
|
||||
|
||||
self.assertGraphMatches(
|
||||
fn_graph,
|
||||
(
|
||||
('a', '(lambda b: (a + b))', '(lambda b: (a + b))(a)'),
|
||||
(('(lambda b: (a + b))', 'pass'), '(lambda b: (a + b))(a)', 'pass'),
|
||||
('(lambda b: (a + b))(a)', 'pass', '(lambda b: (a + b))(a)'),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(fn_graph, 'a', ('(lambda b: (a + b))(a)',))
|
||||
self.assertGraphMatches(
|
||||
lam_graph,
|
||||
(
|
||||
('b', '(a + b)', None),
|
||||
),
|
||||
)
|
||||
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||
|
||||
def test_pass(self):
|
||||
|
||||
def test_fn(a): # pylint:disable=unused-argument
|
||||
|
Loading…
Reference in New Issue
Block a user