Allow local class definitions.

PiperOrigin-RevId: 254756131
This commit is contained in:
Dan Moldovan 2019-06-24 07:47:50 -07:00 committed by TensorFlower Gardener
parent 9a62d2e09e
commit 9365ef2fc3
4 changed files with 110 additions and 0 deletions

View File

@ -74,6 +74,8 @@ class Node(object):
def __repr__(self):
if isinstance(self.ast_node, gast.FunctionDef):
return 'def %s' % self.ast_node.name
elif isinstance(self.ast_node, gast.ClassDef):
return 'class %s' % self.ast_node.name
elif isinstance(self.ast_node, gast.withitem):
return compiler.ast_to_source(self.ast_node.context_expr).strip()
return compiler.ast_to_source(self.ast_node).strip()
@ -659,6 +661,34 @@ class AstToCfg(gast.NodeVisitor):
(node, loops_to_nodes_of_type))
self.builder.add_continue_node(node, try_node, guards)
def visit_ClassDef(self, node):
# We also keep the ClassDef node in the CFG, since it technically is a
# statement.
# For example, this is legal and allows executing user code:
#
# class Foo(bar()):
# pass
#
# It also has a scope:
#
# class Bar(object):
# a = 1
if self.builder is None:
self.generic_visit(node)
return
self.builder.add_ordinary_node(node)
self.builder_stack.append(self.builder)
self.builder = GraphBuilder(node)
self._enter_lexical_scope(node)
self._process_basic_statement(node)
self._exit_lexical_scope(node)
# TODO(mdan): Track the CFG local to the class definition as well?
self.builder = self.builder_stack.pop()
def visit_FunctionDef(self, node):
# We also keep the FunctionDef node in the CFG. This allows us to determine
# things like reaching definitions via closure. Note that the function body

View File

@ -1268,6 +1268,40 @@ class AstToCfgTest(test.TestCase):
),
)
def test_class_definition_empty(self):
def test_fn(a, b):
class C(a(b)):
pass
return C
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a, b', 'class C', 'return C'),
('class C', 'return C', None),
),
)
def test_class_definition_with_members(self):
def test_fn(a, b):
class C(a(b)):
d = 1
return C
graph, = self._build_cfg(test_fn).values()
self.assertGraphMatches(
graph,
(
('a, b', 'class C', 'return C'),
('class C', 'return C', None),
),
)
if __name__ == '__main__':
test.main()

View File

@ -400,6 +400,22 @@ class ActivityAnalyzer(transformer.Base):
def visit_arguments(self, node):
return self._process_statement(node)
def visit_ClassDef(self, node):
# The ClassDef node itself has a Scope object that tracks the creation
# of its name, along with the usage of any decorator accompanying it.
self._enter_scope(False)
node.decorator_list = self.visit_block(node.decorator_list)
self.scope.mark_modified(qual_names.QN(node.name))
anno.setanno(node, anno.Static.SCOPE, self.scope)
self._exit_scope()
# A separate Scope tracks the actual class definition.
self._enter_scope(True)
assert not (self._in_function_def_args or self.state[_Lambda].level)
node = self.generic_visit(node)
self._exit_scope()
return node
def visit_FunctionDef(self, node):
# The FunctionDef node itself has a Scope object that tracks the creation
# of its name, along with the usage of any decorator accompanying it.

View File

@ -518,6 +518,36 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
def test_class_definition_basic(self):
def test_fn(a, b):
class C(a(b)):
d = 1
return C
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('a', 'b', 'C'), ('C',))
def test_class_definition_isolates_method_writes_but_not_reads(self):
def test_fn(a, b, c):
class C(a(b)):
d = 1
def e(self):
f = c + 1
return f
return C
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
# Note: 'f' is in there because we cannot detect thattically that it
# is local to the function itself.
self.assertScopeIs(body_scope, ('a', 'b', 'c', 'f', 'C'), ('C',))
if __name__ == '__main__':
test.main()