Allow local class definitions.
PiperOrigin-RevId: 254756131
This commit is contained in:
parent
9a62d2e09e
commit
9365ef2fc3
@ -74,6 +74,8 @@ class Node(object):
|
||||
def __repr__(self):
|
||||
if isinstance(self.ast_node, gast.FunctionDef):
|
||||
return 'def %s' % self.ast_node.name
|
||||
elif isinstance(self.ast_node, gast.ClassDef):
|
||||
return 'class %s' % self.ast_node.name
|
||||
elif isinstance(self.ast_node, gast.withitem):
|
||||
return compiler.ast_to_source(self.ast_node.context_expr).strip()
|
||||
return compiler.ast_to_source(self.ast_node).strip()
|
||||
@ -659,6 +661,34 @@ class AstToCfg(gast.NodeVisitor):
|
||||
(node, loops_to_nodes_of_type))
|
||||
self.builder.add_continue_node(node, try_node, guards)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
# We also keep the ClassDef node in the CFG, since it technically is a
|
||||
# statement.
|
||||
# For example, this is legal and allows executing user code:
|
||||
#
|
||||
# class Foo(bar()):
|
||||
# pass
|
||||
#
|
||||
# It also has a scope:
|
||||
#
|
||||
# class Bar(object):
|
||||
# a = 1
|
||||
if self.builder is None:
|
||||
self.generic_visit(node)
|
||||
return
|
||||
|
||||
self.builder.add_ordinary_node(node)
|
||||
|
||||
self.builder_stack.append(self.builder)
|
||||
self.builder = GraphBuilder(node)
|
||||
self._enter_lexical_scope(node)
|
||||
|
||||
self._process_basic_statement(node)
|
||||
|
||||
self._exit_lexical_scope(node)
|
||||
# TODO(mdan): Track the CFG local to the class definition as well?
|
||||
self.builder = self.builder_stack.pop()
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# We also keep the FunctionDef node in the CFG. This allows us to determine
|
||||
# things like reaching definitions via closure. Note that the function body
|
||||
|
@ -1268,6 +1268,40 @@ class AstToCfgTest(test.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_class_definition_empty(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
class C(a(b)):
|
||||
pass
|
||||
return C
|
||||
|
||||
graph, = self._build_cfg(test_fn).values()
|
||||
|
||||
self.assertGraphMatches(
|
||||
graph,
|
||||
(
|
||||
('a, b', 'class C', 'return C'),
|
||||
('class C', 'return C', None),
|
||||
),
|
||||
)
|
||||
|
||||
def test_class_definition_with_members(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
class C(a(b)):
|
||||
d = 1
|
||||
return C
|
||||
|
||||
graph, = self._build_cfg(test_fn).values()
|
||||
|
||||
self.assertGraphMatches(
|
||||
graph,
|
||||
(
|
||||
('a, b', 'class C', 'return C'),
|
||||
('class C', 'return C', None),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -400,6 +400,22 @@ class ActivityAnalyzer(transformer.Base):
|
||||
def visit_arguments(self, node):
|
||||
return self._process_statement(node)
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
# The ClassDef node itself has a Scope object that tracks the creation
|
||||
# of its name, along with the usage of any decorator accompanying it.
|
||||
self._enter_scope(False)
|
||||
node.decorator_list = self.visit_block(node.decorator_list)
|
||||
self.scope.mark_modified(qual_names.QN(node.name))
|
||||
anno.setanno(node, anno.Static.SCOPE, self.scope)
|
||||
self._exit_scope()
|
||||
|
||||
# A separate Scope tracks the actual class definition.
|
||||
self._enter_scope(True)
|
||||
assert not (self._in_function_def_args or self.state[_Lambda].level)
|
||||
node = self.generic_visit(node)
|
||||
self._exit_scope()
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# The FunctionDef node itself has a Scope object that tracks the creation
|
||||
# of its name, along with the usage of any decorator accompanying it.
|
||||
|
@ -518,6 +518,36 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
|
||||
|
||||
def test_class_definition_basic(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
class C(a(b)):
|
||||
d = 1
|
||||
return C
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(body_scope, ('a', 'b', 'C'), ('C',))
|
||||
|
||||
def test_class_definition_isolates_method_writes_but_not_reads(self):
|
||||
|
||||
def test_fn(a, b, c):
|
||||
class C(a(b)):
|
||||
d = 1
|
||||
|
||||
def e(self):
|
||||
f = c + 1
|
||||
return f
|
||||
return C
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
fn_node = node
|
||||
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
# Note: 'f' is in there because we cannot detect thattically that it
|
||||
# is local to the function itself.
|
||||
self.assertScopeIs(body_scope, ('a', 'b', 'c', 'f', 'C'), ('C',))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user