Track lambda functions separately in the CFG, consistent with normal function definitions.
PiperOrigin-RevId: 307828205 Change-Id: I1de54beba38db0213d0e921698d236aed1dcc355
This commit is contained in:
parent
ef4e3be946
commit
126f665f39
@ -689,6 +689,7 @@ class AstToCfg(gast.NodeVisitor):
|
|||||||
|
|
||||||
def _process_exit_statement(
|
def _process_exit_statement(
|
||||||
self, node, exits_nodes_of_type, may_exit_via_except=False):
|
self, node, exits_nodes_of_type, may_exit_via_except=False):
|
||||||
|
self.generic_visit(node)
|
||||||
# Note: this is safe because we process functions separately.
|
# Note: this is safe because we process functions separately.
|
||||||
try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type)
|
try_node, guards = self._get_enclosing_finally_scopes(exits_nodes_of_type)
|
||||||
assert try_node is not None, '{} that is not enclosed by any of {}'.format(
|
assert try_node is not None, '{} that is not enclosed by any of {}'.format(
|
||||||
@ -737,11 +738,9 @@ class AstToCfg(gast.NodeVisitor):
|
|||||||
# TODO(mdan): Track the CFG local to the class definition as well?
|
# TODO(mdan): Track the CFG local to the class definition as well?
|
||||||
self.builder = self.builder_stack.pop()
|
self.builder = self.builder_stack.pop()
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def _process_function_def(self, node, is_lambda):
|
||||||
# We also keep the FunctionDef node in the CFG. This allows us to determine
|
# The function body is stored in a separate graph, because function
|
||||||
# things like reaching definitions via closure. Note that the function body
|
# definitions have effects very different from function calls.
|
||||||
# will be stored in a separate graph, because function definitions are not
|
|
||||||
# the same as function calls.
|
|
||||||
if self.builder is not None:
|
if self.builder is not None:
|
||||||
self.builder.add_ordinary_node(node)
|
self.builder.add_ordinary_node(node)
|
||||||
|
|
||||||
@ -752,6 +751,9 @@ class AstToCfg(gast.NodeVisitor):
|
|||||||
self.builder.enter_section(node)
|
self.builder.enter_section(node)
|
||||||
|
|
||||||
self._process_basic_statement(node.args)
|
self._process_basic_statement(node.args)
|
||||||
|
if is_lambda:
|
||||||
|
self._process_exit_statement(node.body, (gast.Lambda,))
|
||||||
|
else:
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
self.visit(stmt)
|
self.visit(stmt)
|
||||||
|
|
||||||
@ -761,6 +763,12 @@ class AstToCfg(gast.NodeVisitor):
|
|||||||
self.cfgs[node] = self.builder.build()
|
self.cfgs[node] = self.builder.build()
|
||||||
self.builder = self.builder_stack.pop()
|
self.builder = self.builder_stack.pop()
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
self._process_function_def(node, is_lambda=False)
|
||||||
|
|
||||||
|
def visit_Lambda(self, node):
|
||||||
|
self._process_function_def(node, is_lambda=True)
|
||||||
|
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
self._process_exit_statement(node, (gast.FunctionDef,))
|
self._process_exit_statement(node, (gast.FunctionDef,))
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
from tensorflow.python.autograph.pyct import cfg
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -1030,16 +1032,57 @@ class AstToCfgTest(test.TestCase):
|
|||||||
a = lambda b: a + b
|
a = lambda b: a + b
|
||||||
return a
|
return a
|
||||||
|
|
||||||
graph, = self._build_cfg(test_fn).values()
|
graphs = self._build_cfg(test_fn)
|
||||||
|
for k, v in graphs.items():
|
||||||
|
if isinstance(k, gast.Lambda):
|
||||||
|
lam_graph = v
|
||||||
|
else:
|
||||||
|
fn_graph = v
|
||||||
|
|
||||||
self.assertGraphMatches(
|
self.assertGraphMatches(
|
||||||
graph,
|
fn_graph,
|
||||||
(
|
(
|
||||||
('a', 'a = (lambda b: (a + b))', 'return a'),
|
('a', '(lambda b: (a + b))', 'a = (lambda b: (a + b))'),
|
||||||
|
('(lambda b: (a + b))', 'a = (lambda b: (a + b))', 'return a'),
|
||||||
('a = (lambda b: (a + b))', 'return a', None),
|
('a = (lambda b: (a + b))', 'return a', None),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertGraphEnds(graph, 'a', ('return a',))
|
self.assertGraphEnds(fn_graph, 'a', ('return a',))
|
||||||
|
self.assertGraphMatches(
|
||||||
|
lam_graph,
|
||||||
|
(
|
||||||
|
('b', '(a + b)', None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||||
|
|
||||||
|
def test_lambda_in_return(self):
|
||||||
|
|
||||||
|
def test_fn(a):
|
||||||
|
return lambda b: a + b
|
||||||
|
|
||||||
|
graphs = self._build_cfg(test_fn)
|
||||||
|
for k, v in graphs.items():
|
||||||
|
if isinstance(k, gast.Lambda):
|
||||||
|
lam_graph = v
|
||||||
|
else:
|
||||||
|
fn_graph = v
|
||||||
|
|
||||||
|
self.assertGraphMatches(
|
||||||
|
fn_graph,
|
||||||
|
(
|
||||||
|
('a', '(lambda b: (a + b))', 'return (lambda b: (a + b))'),
|
||||||
|
('(lambda b: (a + b))', 'return (lambda b: (a + b))', None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertGraphEnds(fn_graph, 'a', ('return (lambda b: (a + b))',))
|
||||||
|
self.assertGraphMatches(
|
||||||
|
lam_graph,
|
||||||
|
(
|
||||||
|
('b', '(a + b)', None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertGraphEnds(lam_graph, 'b', ('(a + b)',))
|
||||||
|
|
||||||
def test_pass(self):
|
def test_pass(self):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user