diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index 427bde67852..14e16234107 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -74,6 +74,8 @@ class Node(object): def __repr__(self): if isinstance(self.ast_node, gast.FunctionDef): return 'def %s' % self.ast_node.name + elif isinstance(self.ast_node, gast.ClassDef): + return 'class %s' % self.ast_node.name elif isinstance(self.ast_node, gast.withitem): return compiler.ast_to_source(self.ast_node.context_expr).strip() return compiler.ast_to_source(self.ast_node).strip() @@ -659,6 +661,34 @@ class AstToCfg(gast.NodeVisitor): (node, loops_to_nodes_of_type)) self.builder.add_continue_node(node, try_node, guards) + def visit_ClassDef(self, node): + # We also keep the ClassDef node in the CFG, since it technically is a + # statement. + # For example, this is legal and allows executing user code: + # + # class Foo(bar()): + # pass + # + # It also has a scope: + # + # class Bar(object): + # a = 1 + if self.builder is None: + self.generic_visit(node) + return + + self.builder.add_ordinary_node(node) + + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + self._enter_lexical_scope(node) + + self._process_basic_statement(node) + + self._exit_lexical_scope(node) + # TODO(mdan): Track the CFG local to the class definition as well? + self.builder = self.builder_stack.pop() + def visit_FunctionDef(self, node): # We also keep the FunctionDef node in the CFG. This allows us to determine # things like reaching definitions via closure. Note that the function body diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py index 72ef9ee6865..4a95f25caa1 100644 --- a/tensorflow/python/autograph/pyct/cfg_test.py +++ b/tensorflow/python/autograph/pyct/cfg_test.py @@ -1268,6 +1268,40 @@ class AstToCfgTest(test.TestCase): ), ) + def test_class_definition_empty(self): + + def test_fn(a, b): + class C(a(b)): + pass + return C + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a, b', 'class C', 'return C'), + ('class C', 'return C', None), + ), + ) + + def test_class_definition_with_members(self): + + def test_fn(a, b): + class C(a(b)): + d = 1 + return C + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a, b', 'class C', 'return C'), + ('class C', 'return C', None), + ), + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index 0048e778492..5931198620f 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -400,6 +400,22 @@ class ActivityAnalyzer(transformer.Base): def visit_arguments(self, node): return self._process_statement(node) + def visit_ClassDef(self, node): + # The ClassDef node itself has a Scope object that tracks the creation + # of its name, along with the usage of any decorator accompanying it. + self._enter_scope(False) + node.decorator_list = self.visit_block(node.decorator_list) + self.scope.mark_modified(qual_names.QN(node.name)) + anno.setanno(node, anno.Static.SCOPE, self.scope) + self._exit_scope() + + # A separate Scope tracks the actual class definition. + self._enter_scope(True) + assert not (self._in_function_def_args or self.state[_Lambda].level) + node = self.generic_visit(node) + self._exit_scope() + return node + def visit_FunctionDef(self, node): # The FunctionDef node itself has a Scope object that tracks the creation # of its name, along with the usage of any decorator accompanying it. diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index 4462fbea187..e1596673825 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -518,6 +518,36 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase): body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',)) + def test_class_definition_basic(self): + + def test_fn(a, b): + class C(a(b)): + d = 1 + return C + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + self.assertScopeIs(body_scope, ('a', 'b', 'C'), ('C',)) + + def test_class_definition_isolates_method_writes_but_not_reads(self): + + def test_fn(a, b, c): + class C(a(b)): + d = 1 + + def e(self): + f = c + 1 + return f + return C + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + # Note: 'f' is in there because we cannot detect thattically that it + # is local to the function itself. + self.assertScopeIs(body_scope, ('a', 'b', 'c', 'f', 'C'), ('C',)) + if __name__ == '__main__': test.main()