From c79154ef495a2fa66b0257dbfa3b36a3a68daf8d Mon Sep 17 00:00:00 2001 From: Brian Lee Date: Mon, 18 Mar 2019 17:29:28 -0700 Subject: [PATCH] Preserve __future__ imports throughout the conversion process. PiperOrigin-RevId: 239095963 --- .../autograph/converters/directives_test.py | 2 +- .../autograph/core/converter_testing.py | 2 +- .../python/autograph/impl/conversion.py | 30 ++++++++++++-- .../python/autograph/impl/conversion_test.py | 14 +++---- tensorflow/python/autograph/pyct/cfg_test.py | 4 +- .../pyct/common_transformers/anf_test.py | 6 +-- .../python/autograph/pyct/compiler_test.py | 4 +- .../python/autograph/pyct/origin_info_test.py | 11 ++--- tensorflow/python/autograph/pyct/parser.py | 41 +++++++++---------- .../python/autograph/pyct/parser_test.py | 39 ++++-------------- .../pyct/static_analysis/activity_test.py | 2 +- .../pyct/static_analysis/live_values_test.py | 2 +- .../pyct/static_analysis/liveness_test.py | 2 +- .../reaching_definitions_test.py | 2 +- .../pyct/static_analysis/type_info_test.py | 2 +- .../python/autograph/pyct/transformer_test.py | 20 ++++----- 16 files changed, 92 insertions(+), 91 deletions(-) diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py index 870a491ccdf..22e3e9d5cfa 100644 --- a/tensorflow/python/autograph/converters/directives_test.py +++ b/tensorflow/python/autograph/converters/directives_test.py @@ -84,7 +84,7 @@ class DirectivesTest(converter_testing.TestCase): def call_invalid_directive(): invalid_directive(1) - node, _, _ = parser.parse_entity(call_invalid_directive) + node, _, _ = parser.parse_entity(call_invalid_directive, future_imports=()) # Find the call to the invalid directive node = node.body[0].value with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index e2d95b89095..eae03cfd548 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -122,7 +122,7 @@ class TestCase(test.TestCase): def prepare(self, test_fn, namespace, arg_types=None, recursive=True): namespace['ConversionOptions'] = converter.ConversionOptions - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) namer = naming.Namer(namespace) program_ctx = converter.ProgramContext( options=converter.ConversionOptions(recursive=recursive), diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index bb9464c3361..09c262a38c9 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -235,6 +235,21 @@ def class_to_graph(c, program_ctx): if not members: raise ValueError('Cannot convert %s: it has no member methods.' % c) + # TODO(mdan): Don't clobber namespaces for each method in one class namespace. + # The assumption that one namespace suffices for all methods only holds if + # all methods were defined in the same module. + # If, instead, functions are imported from multiple modules and then spliced + # into the class, then each function has its own globals and __future__ + # imports that need to stay separate. + + # For example, C's methods could both have `global x` statements referring to + # mod1.x and mod2.x, but using one namespace for C would cause a conflict. + # from mod1 import f1 + # from mod2 import f2 + # class C(object): + # method1 = f1 + # method2 = f2 + class_namespace = {} for _, m in members: # Only convert the members that are directly defined by the class. @@ -250,7 +265,13 @@ def class_to_graph(c, program_ctx): class_namespace = namespace else: class_namespace.update(namespace) - converted_members[m] = nodes[0] + # TODO(brianklee): function_to_graph returns future import nodes and the + # converted function nodes. We discard all the future import nodes here + # which is buggy behavior, but really, the whole approach of gathering all + # of the converted function nodes in one place is intrinsically buggy. + # So this is a reminder to properly handle the future import nodes when we + # redo our approach to class conversion. + converted_members[m] = nodes[-1] namer = naming.Namer(class_namespace) class_name = namer.class_name(c.__name__) @@ -335,8 +356,11 @@ def _add_self_references(namespace, autograph_module): def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True): """Specialization of `entity_to_graph` for callable functions.""" - node, source, _ = parser.parse_entity(f) + future_imports = inspect_utils.getfutureimports(f) + node, future_import_nodes, source = parser.parse_entity( + f, future_imports=future_imports) logging.log(3, 'Source code of %s:\n\n%s\n', f, source) + # Parsed AST should contain future imports and one function def node. # In general, the output of inspect.getsource is inexact for lambdas because # it uses regex matching to adjust the exact location around the line number @@ -386,7 +410,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True): new_name = f.__name__ assert node.name == new_name - return [node], new_name, namespace + return future_import_nodes + [node], new_name, namespace def node_to_graph(node, context): diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index 7902fa697f6..6480091fbf6 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -58,7 +58,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) self.assertIs(ns['b'], b) @@ -71,7 +71,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) self.assertEqual( @@ -87,7 +87,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None) - f_node = nodes[0] + f_node = nodes[-2] self.assertEqual('tf__f', f_node.name) def test_entity_to_graph_class_hierarchy(self): @@ -144,7 +144,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node.value, gast.Lambda) self.assertEqual('tf__lambda', name) @@ -156,7 +156,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node.value, gast.Lambda) self.assertEqual('tf__lambda', name) @@ -179,7 +179,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node.value, gast.Lambda) self.assertEqual('tf__lambda', name) @@ -194,7 +194,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, _ = nodes + fn_node = nodes[-2] self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual(fn_node.name, 'tf__f') self.assertEqual('tf__f', name) diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py index 8fb66ca7a76..6748a77aa47 100644 --- a/tensorflow/python/autograph/pyct/cfg_test.py +++ b/tensorflow/python/autograph/pyct/cfg_test.py @@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor): class GraphVisitorTest(test.TestCase): def _build_cfg(self, fn): - node, _, _ = parser.parse_entity(fn) + node, _, _ = parser.parse_entity(fn, future_imports=()) cfgs = cfg.build(node) return cfgs, node @@ -91,7 +91,7 @@ class GraphVisitorTest(test.TestCase): class AstToCfgTest(test.TestCase): def _build_cfg(self, fn): - node, _, _ = parser.parse_entity(fn) + node, _, _ = parser.parse_entity(fn, future_imports=()) cfgs = cfg.build(node) return cfgs diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index d7750604778..ee68c5887d3 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -81,7 +81,7 @@ class AnfTransformerTest(test.TestCase): def test_function(): a = 0 return a - node, _, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) node = anf.transform(node, self._simple_context()) result, _ = compiler.ast_to_object(node) self.assertEqual(test_function(), result.test_function()) @@ -97,8 +97,8 @@ class AnfTransformerTest(test.TestCase): # Testing the code bodies only. Wrapping them in functions so the # syntax highlights nicely, but Python doesn't try to execute the # statements. - exp_node, _, _ = parser.parse_entity(expected_fn) - node, _, _ = parser.parse_entity(test_fn) + exp_node, _, _ = parser.parse_entity(expected_fn, future_imports=()) + node, _, _ = parser.parse_entity(test_fn, future_imports=()) node = anf.transform( node, self._simple_context(), gensym_source=DummyGensym) exp_name = exp_node.name diff --git a/tensorflow/python/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py index 29e8a198fe6..490566cc291 100644 --- a/tensorflow/python/autograph/pyct/compiler_test.py +++ b/tensorflow/python/autograph/pyct/compiler_test.py @@ -39,12 +39,12 @@ class CompilerTest(test.TestCase): b = x + 1 return b - _, _, all_nodes = parser.parse_entity(test_fn) + node, _, _ = parser.parse_entity(test_fn, future_imports=()) self.assertEqual( textwrap.dedent(tf_inspect.getsource(test_fn)), tf_inspect.getsource( - compiler.ast_to_object(all_nodes)[0].test_fn)) + compiler.ast_to_object([node])[0].test_fn)) def test_ast_to_source(self): node = gast.If( diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py index a3dc2f82716..e386c86a6b0 100644 --- a/tensorflow/python/autograph/pyct/origin_info_test.py +++ b/tensorflow/python/autograph/pyct/origin_info_test.py @@ -32,7 +32,7 @@ class OriginInfoTest(test.TestCase): def test_fn(x): return x + 1 - node, _, _ = parser.parse_entity(test_fn) + node, _, _ = parser.parse_entity(test_fn, future_imports=()) fake_origin = origin_info.OriginInfo( loc=origin_info.Location('fake_filename', 3, 7), function_name='fake_function_name', @@ -53,7 +53,7 @@ class OriginInfoTest(test.TestCase): def test_fn(x): return x + 1 - node, _, _ = parser.parse_entity(test_fn) + node, _, _ = parser.parse_entity(test_fn, future_imports=()) converted_code = compiler.ast_to_source(node) source_map = origin_info.create_source_map( @@ -67,7 +67,7 @@ class OriginInfoTest(test.TestCase): """Docstring.""" return x # comment - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) origin_info.resolve(node, source) @@ -89,14 +89,15 @@ class OriginInfoTest(test.TestCase): self.assertEqual(origin.source_code_line, ' return x # comment') self.assertEqual(origin.comment, 'comment') - def disabled_test_resolve_with_future_imports(self): + def test_resolve_with_future_imports(self): def test_fn(x): """Docstring.""" print(x) return x # comment - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity( + test_fn, future_imports=['print_function']) origin_info.resolve(node, source) diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py index d6f51741809..d3c4780c794 100644 --- a/tensorflow/python/autograph/pyct/parser.py +++ b/tensorflow/python/autograph/pyct/parser.py @@ -21,12 +21,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import re +import itertools import textwrap import threading import gast -import six from tensorflow.python.util import tf_inspect @@ -34,17 +33,18 @@ from tensorflow.python.util import tf_inspect _parse_lock = threading.Lock() # Prevents linecache concurrency errors. -def parse_entity(entity): +def parse_entity(entity, future_imports): """Returns the AST and source code of given entity. Args: entity: A python function/method/class + future_imports: An iterable of future imports to use when parsing AST. (e.g. + ('print_statement', 'division', 'unicode_literals')) Returns: - gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding - exactly to the entity; the string that was parsed to generate the AST; and - the containing module AST node, which might contain extras like future - import nodes. + gast.AST, List[gast.AST], str: a tuple of the AST node corresponding + exactly to the entity; a list of future import AST nodes, and the string + that was parsed to generate the AST. """ try: with _parse_lock: @@ -67,11 +67,13 @@ def parse_entity(entity): # causing textwrap.dedent to not correctly dedent source code. # TODO(b/115884650): Automatic handling of comments/multiline strings. source = textwrap.dedent(source) + future_import_strings = ('from __future__ import {}'.format(name) + for name in future_imports) + source = '\n'.join(itertools.chain(future_import_strings, [source])) try: module_node = parse_str(source) - assert len(module_node.body) == 1 - return module_node.body[0], source, module_node + return _select_entity_node(module_node, source, future_imports) except IndentationError: # The text below lists the causes of this error known to us. There may @@ -112,7 +114,7 @@ def parse_entity(entity): try: module_node = parse_str(new_source) - return module_node.body[0], new_source, module_node + return _select_entity_node(module_node, new_source, future_imports) except SyntaxError as e: raise_parse_failure( 'If this is a lambda function, the error may be avoided by creating' @@ -122,18 +124,7 @@ def parse_entity(entity): def parse_str(src): """Returns the AST of given piece of code.""" - # TODO(mdan): This should exclude the module things are autowrapped in. - - if six.PY2 and re.search('\\Wprint\\s*\\(', src): - # This special treatment is required because gast.parse is not aware of - # whether print_function was present in the original context. - src = 'from __future__ import print_function\n' + src - parsed_module = gast.parse(src) - parsed_module.body = parsed_module.body[1:] - else: - parsed_module = gast.parse(src) - - return parsed_module + return gast.parse(src) def parse_expression(src): @@ -152,3 +143,9 @@ def parse_expression(src): raise ValueError( 'Expected a single expression, found instead %s' % node.body) return node.body[0].value + + +def _select_entity_node(module_node, source, future_imports): + assert len(module_node.body) == 1 + len(future_imports) + return module_node.body[-1], module_node.body[:-1], source + diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py index ee3e2808259..1f1bda07677 100644 --- a/tensorflow/python/autograph/pyct/parser_test.py +++ b/tensorflow/python/autograph/pyct/parser_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import textwrap - from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test @@ -31,41 +29,22 @@ class ParserTest(test.TestCase): def f(x): return x + 1 - node, _, _ = parser.parse_entity(f) + node, _, _ = parser.parse_entity(f, future_imports=()) self.assertEqual('f', node.name) - def test_parse_str(self): - mod = parser.parse_str( - textwrap.dedent(""" - def f(x): - return x + 1 - """)) - self.assertEqual('f', mod.body[0].name) - - def test_parse_str_print(self): - mod = parser.parse_str( - textwrap.dedent(""" - def f(x): - print(x) - return x + 1 - """)) - self.assertEqual('f', mod.body[0].name) - - def test_parse_str_weird_print(self): - mod = parser.parse_str( - textwrap.dedent(""" - def f(x): - print (x) - return x + 1 - """)) - self.assertEqual('f', mod.body[0].name) + def test_parse_entity_print_function(self): + def f(x): + print(x) + node, _, _ = parser.parse_entity( + f, future_imports=['print_function']) + self.assertEqual('f', node.name) def test_parse_comments(self): def f(): # unindented comment pass with self.assertRaises(ValueError): - parser.parse_entity(f) + parser.parse_entity(f, future_imports=()) def test_parse_multiline_strings(self): def f(): @@ -74,7 +53,7 @@ some multiline string""") with self.assertRaises(ValueError): - parser.parse_entity(f) + parser.parse_entity(f, future_imports=()) def test_parse_expression(self): node = parser.parse_expression('a.b') diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index ef3390e03fa..63e62583c92 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -112,7 +112,7 @@ class ScopeTest(test.TestCase): class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) entity_info = transformer.EntityInfo( source_code=source, source_file=None, diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py index 14bb3682e3b..8010ce9268c 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py @@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) entity_info = transformer.EntityInfo( source_code=source, source_file=None, diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index c32abb9efd1..2017c4fc9e3 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test class LivenessTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) entity_info = transformer.EntityInfo( source_code=source, source_file=None, diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index 3359886f50d..2b051f170a7 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test class DefinitionInfoTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) entity_info = transformer.EntityInfo( source_code=source, source_file=None, diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py index 42e52a6b3b9..15f1ed4ff0e 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py @@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase): test_fn, namespace, arg_types=None): - node, source, _ = parser.parse_entity(test_fn) + node, _, source = parser.parse_entity(test_fn, future_imports=()) entity_info = transformer.EntityInfo( source_code=source, source_file=None, diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py index bd19ebad5c5..89a44e5d58f 100644 --- a/tensorflow/python/autograph/pyct/transformer_test.py +++ b/tensorflow/python/autograph/pyct/transformer_test.py @@ -68,7 +68,7 @@ class TransformerTest(test.TestCase): return b, inner_function return a, TestClass - node, _, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) node = tr.visit(node) test_function_node = node @@ -141,7 +141,7 @@ class TransformerTest(test.TestCase): while True: raise '1' - node, _, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) node = tr.visit(node) fn_body = node.body @@ -207,7 +207,7 @@ class TransformerTest(test.TestCase): raise '1' return 'nor this' - node, _, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) node = tr.visit(node) for_node = node.body[2] @@ -238,7 +238,7 @@ class TransformerTest(test.TestCase): print(a) return None - node, _, _ = parser.parse_entity(no_exit) + node, _, _ = parser.parse_entity(no_exit, future_imports=()) with self.assertRaises(AssertionError): tr.visit(node) @@ -246,7 +246,7 @@ class TransformerTest(test.TestCase): for _ in a: print(a) - node, _, _ = parser.parse_entity(no_entry) + node, _, _ = parser.parse_entity(no_entry, future_imports=()) with self.assertRaises(AssertionError): tr.visit(node) @@ -272,7 +272,7 @@ class TransformerTest(test.TestCase): tr = TestTransformer(self._simple_context()) - node, _, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) node = tr.visit(node) self.assertEqual(len(node.body), 2) @@ -302,9 +302,9 @@ class TransformerTest(test.TestCase): tr = BrokenTransformer(self._simple_context()) - _, _, all_nodes = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) with self.assertRaises(ValueError) as cm: - all_nodes = tr.visit(all_nodes) + node = tr.visit(node) obtained_message = str(cm.exception) expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"' self.assertRegexpMatches(obtained_message, expected_message) @@ -333,9 +333,9 @@ class TransformerTest(test.TestCase): tr = BrokenTransformer(self._simple_context()) - _, _, all_nodes = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function, future_imports=()) with self.assertRaises(ValueError) as cm: - all_nodes = tr.visit(all_nodes) + node = tr.visit(node) obtained_message = str(cm.exception) # The message should reference the exception actually raised, not anything # from the exception handler.