Fix CFG to correctly account for lambda functions that appear in while and for loops.

PiperOrigin-RevId: 307956866
Change-Id: Ifed545cf034c306d6655af0a3ec48eaa0c9b68d3
This commit is contained in:
Dan Moldovan 2020-04-22 19:39:56 -07:00 committed by TensorFlower Gardener
parent 7bd0df753f
commit 8d72dbe4ea
2 changed files with 62 additions and 0 deletions
tensorflow/python/autograph/pyct

View File

@ -842,6 +842,7 @@ class AstToCfg(gast.NodeVisitor):
self.builder.enter_section(node)
self.generic_visit(node.test)
self.builder.enter_loop_section(node, node.test)
for stmt in node.body:
self.visit(stmt)
@ -867,6 +868,7 @@ class AstToCfg(gast.NodeVisitor):
# Note: Strictly speaking, this should be node.target + node.iter.
# However, the activity analysis accounts for this inconsistency,
# so dataflow analysis produces the correct values.
self.generic_visit(node.iter)
self.builder.enter_loop_section(node, node.iter)
# Also include the "extra loop test" annotation, to capture things like the
# control variable for return and break in for loops.

View File

@ -1084,6 +1084,66 @@ class AstToCfgTest(test.TestCase):
)
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
def test_lambda_in_while_loop_test(self):
def test_fn(a):
while (lambda b: a + b)(a):
pass
graphs = self._build_cfg(test_fn)
for k, v in graphs.items():
if isinstance(k, gast.Lambda):
lam_graph = v
else:
fn_graph = v
self.assertGraphMatches(
fn_graph,
(
('a', '(lambda b: (a + b))', '(lambda b: (a + b))(a)'),
(('(lambda b: (a + b))', 'pass'), '(lambda b: (a + b))(a)', 'pass'),
('(lambda b: (a + b))(a)', 'pass', '(lambda b: (a + b))(a)'),
),
)
self.assertGraphEnds(fn_graph, 'a', ('(lambda b: (a + b))(a)',))
self.assertGraphMatches(
lam_graph,
(
('b', '(a + b)', None),
),
)
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
def test_lambda_in_for_loop_test(self):
def test_fn(a):
for _ in (lambda b: a + b)(a):
pass
graphs = self._build_cfg(test_fn)
for k, v in graphs.items():
if isinstance(k, gast.Lambda):
lam_graph = v
else:
fn_graph = v
self.assertGraphMatches(
fn_graph,
(
('a', '(lambda b: (a + b))', '(lambda b: (a + b))(a)'),
(('(lambda b: (a + b))', 'pass'), '(lambda b: (a + b))(a)', 'pass'),
('(lambda b: (a + b))(a)', 'pass', '(lambda b: (a + b))(a)'),
),
)
self.assertGraphEnds(fn_graph, 'a', ('(lambda b: (a + b))(a)',))
self.assertGraphMatches(
lam_graph,
(
('b', '(a + b)', None),
),
)
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
def test_pass(self):
def test_fn(a): # pylint:disable=unused-argument