Preserve __future__ imports throughout the conversion process.

PiperOrigin-RevId: 239095963
This commit is contained in:
Brian Lee 2019-03-18 17:29:28 -07:00 committed by TensorFlower Gardener
parent b7a1100569
commit c79154ef49
16 changed files with 92 additions and 91 deletions

View File

@ -84,7 +84,7 @@ class DirectivesTest(converter_testing.TestCase):
def call_invalid_directive():
invalid_directive(1)
node, _, _ = parser.parse_entity(call_invalid_directive)
node, _, _ = parser.parse_entity(call_invalid_directive, future_imports=())
# Find the call to the invalid directive
node = node.body[0].value
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):

View File

@ -122,7 +122,7 @@ class TestCase(test.TestCase):
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
namespace['ConversionOptions'] = converter.ConversionOptions
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
namer = naming.Namer(namespace)
program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(recursive=recursive),

View File

@ -235,6 +235,21 @@ def class_to_graph(c, program_ctx):
if not members:
raise ValueError('Cannot convert %s: it has no member methods.' % c)
# TODO(mdan): Don't clobber namespaces for each method in one class namespace.
# The assumption that one namespace suffices for all methods only holds if
# all methods were defined in the same module.
# If, instead, functions are imported from multiple modules and then spliced
# into the class, then each function has its own globals and __future__
# imports that need to stay separate.
# For example, C's methods could both have `global x` statements referring to
# mod1.x and mod2.x, but using one namespace for C would cause a conflict.
# from mod1 import f1
# from mod2 import f2
# class C(object):
# method1 = f1
# method2 = f2
class_namespace = {}
for _, m in members:
# Only convert the members that are directly defined by the class.
@ -250,7 +265,13 @@ def class_to_graph(c, program_ctx):
class_namespace = namespace
else:
class_namespace.update(namespace)
converted_members[m] = nodes[0]
# TODO(brianklee): function_to_graph returns future import nodes and the
# converted function nodes. We discard all the future import nodes here
# which is buggy behavior, but really, the whole approach of gathering all
# of the converted function nodes in one place is intrinsically buggy.
# So this is a reminder to properly handle the future import nodes when we
# redo our approach to class conversion.
converted_members[m] = nodes[-1]
namer = naming.Namer(class_namespace)
class_name = namer.class_name(c.__name__)
@ -335,8 +356,11 @@ def _add_self_references(namespace, autograph_module):
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
"""Specialization of `entity_to_graph` for callable functions."""
node, source, _ = parser.parse_entity(f)
future_imports = inspect_utils.getfutureimports(f)
node, future_import_nodes, source = parser.parse_entity(
f, future_imports=future_imports)
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
# Parsed AST should contain future imports and one function def node.
# In general, the output of inspect.getsource is inexact for lambdas because
# it uses regex matching to adjust the exact location around the line number
@ -386,7 +410,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
new_name = f.__name__
assert node.name == new_name
return [node], new_name, namespace
return future_import_nodes + [node], new_name, namespace
def node_to_graph(node, context):

View File

@ -58,7 +58,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertIs(ns['b'], b)
@ -71,7 +71,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertEqual(
@ -87,7 +87,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None)
f_node = nodes[0]
f_node = nodes[-2]
self.assertEqual('tf__f', f_node.name)
def test_entity_to_graph_class_hierarchy(self):
@ -144,7 +144,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
@ -156,7 +156,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
@ -179,7 +179,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
@ -194,7 +194,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes
fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual(fn_node.name, 'tf__f')
self.assertEqual('tf__f', name)

View File

@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor):
class GraphVisitorTest(test.TestCase):
def _build_cfg(self, fn):
node, _, _ = parser.parse_entity(fn)
node, _, _ = parser.parse_entity(fn, future_imports=())
cfgs = cfg.build(node)
return cfgs, node
@ -91,7 +91,7 @@ class GraphVisitorTest(test.TestCase):
class AstToCfgTest(test.TestCase):
def _build_cfg(self, fn):
node, _, _ = parser.parse_entity(fn)
node, _, _ = parser.parse_entity(fn, future_imports=())
cfgs = cfg.build(node)
return cfgs

View File

@ -81,7 +81,7 @@ class AnfTransformerTest(test.TestCase):
def test_function():
a = 0
return a
node, _, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
node = anf.transform(node, self._simple_context())
result, _ = compiler.ast_to_object(node)
self.assertEqual(test_function(), result.test_function())
@ -97,8 +97,8 @@ class AnfTransformerTest(test.TestCase):
# Testing the code bodies only. Wrapping them in functions so the
# syntax highlights nicely, but Python doesn't try to execute the
# statements.
exp_node, _, _ = parser.parse_entity(expected_fn)
node, _, _ = parser.parse_entity(test_fn)
exp_node, _, _ = parser.parse_entity(expected_fn, future_imports=())
node, _, _ = parser.parse_entity(test_fn, future_imports=())
node = anf.transform(
node, self._simple_context(), gensym_source=DummyGensym)
exp_name = exp_node.name

View File

@ -39,12 +39,12 @@ class CompilerTest(test.TestCase):
b = x + 1
return b
_, _, all_nodes = parser.parse_entity(test_fn)
node, _, _ = parser.parse_entity(test_fn, future_imports=())
self.assertEqual(
textwrap.dedent(tf_inspect.getsource(test_fn)),
tf_inspect.getsource(
compiler.ast_to_object(all_nodes)[0].test_fn))
compiler.ast_to_object([node])[0].test_fn))
def test_ast_to_source(self):
node = gast.If(

View File

@ -32,7 +32,7 @@ class OriginInfoTest(test.TestCase):
def test_fn(x):
return x + 1
node, _, _ = parser.parse_entity(test_fn)
node, _, _ = parser.parse_entity(test_fn, future_imports=())
fake_origin = origin_info.OriginInfo(
loc=origin_info.Location('fake_filename', 3, 7),
function_name='fake_function_name',
@ -53,7 +53,7 @@ class OriginInfoTest(test.TestCase):
def test_fn(x):
return x + 1
node, _, _ = parser.parse_entity(test_fn)
node, _, _ = parser.parse_entity(test_fn, future_imports=())
converted_code = compiler.ast_to_source(node)
source_map = origin_info.create_source_map(
@ -67,7 +67,7 @@ class OriginInfoTest(test.TestCase):
"""Docstring."""
return x # comment
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
origin_info.resolve(node, source)
@ -89,14 +89,15 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(origin.source_code_line, ' return x # comment')
self.assertEqual(origin.comment, 'comment')
def disabled_test_resolve_with_future_imports(self):
def test_resolve_with_future_imports(self):
def test_fn(x):
"""Docstring."""
print(x)
return x # comment
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(
test_fn, future_imports=['print_function'])
origin_info.resolve(node, source)

View File

@ -21,12 +21,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import itertools
import textwrap
import threading
import gast
import six
from tensorflow.python.util import tf_inspect
@ -34,17 +33,18 @@ from tensorflow.python.util import tf_inspect
_parse_lock = threading.Lock() # Prevents linecache concurrency errors.
def parse_entity(entity):
def parse_entity(entity, future_imports):
"""Returns the AST and source code of given entity.
Args:
entity: A python function/method/class
future_imports: An iterable of future imports to use when parsing AST. (e.g.
('print_function', 'division', 'unicode_literals'))
Returns:
gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding
exactly to the entity; the string that was parsed to generate the AST; and
the containing module AST node, which might contain extras like future
import nodes.
gast.AST, List[gast.AST], str: a tuple of the AST node corresponding
exactly to the entity; a list of future import AST nodes, and the string
that was parsed to generate the AST.
"""
try:
with _parse_lock:
@ -67,11 +67,13 @@ def parse_entity(entity):
# causing textwrap.dedent to not correctly dedent source code.
# TODO(b/115884650): Automatic handling of comments/multiline strings.
source = textwrap.dedent(source)
future_import_strings = ('from __future__ import {}'.format(name)
for name in future_imports)
source = '\n'.join(itertools.chain(future_import_strings, [source]))
try:
module_node = parse_str(source)
assert len(module_node.body) == 1
return module_node.body[0], source, module_node
return _select_entity_node(module_node, source, future_imports)
except IndentationError:
# The text below lists the causes of this error known to us. There may
@ -112,7 +114,7 @@ def parse_entity(entity):
try:
module_node = parse_str(new_source)
return module_node.body[0], new_source, module_node
return _select_entity_node(module_node, new_source, future_imports)
except SyntaxError as e:
raise_parse_failure(
'If this is a lambda function, the error may be avoided by creating'
@ -122,18 +124,7 @@ def parse_entity(entity):
def parse_str(src):
"""Returns the AST of given piece of code."""
# TODO(mdan): This should exclude the module things are autowrapped in.
if six.PY2 and re.search('\\Wprint\\s*\\(', src):
# This special treatment is required because gast.parse is not aware of
# whether print_function was present in the original context.
src = 'from __future__ import print_function\n' + src
parsed_module = gast.parse(src)
parsed_module.body = parsed_module.body[1:]
else:
parsed_module = gast.parse(src)
return parsed_module
return gast.parse(src)
def parse_expression(src):
@ -152,3 +143,9 @@ def parse_expression(src):
raise ValueError(
'Expected a single expression, found instead %s' % node.body)
return node.body[0].value
def _select_entity_node(module_node, source, future_imports):
assert len(module_node.body) == 1 + len(future_imports)
return module_node.body[-1], module_node.body[:-1], source

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import textwrap
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
@ -31,41 +29,22 @@ class ParserTest(test.TestCase):
def f(x):
return x + 1
node, _, _ = parser.parse_entity(f)
node, _, _ = parser.parse_entity(f, future_imports=())
self.assertEqual('f', node.name)
def test_parse_str(self):
mod = parser.parse_str(
textwrap.dedent("""
def f(x):
return x + 1
"""))
self.assertEqual('f', mod.body[0].name)
def test_parse_str_print(self):
mod = parser.parse_str(
textwrap.dedent("""
def f(x):
print(x)
return x + 1
"""))
self.assertEqual('f', mod.body[0].name)
def test_parse_str_weird_print(self):
mod = parser.parse_str(
textwrap.dedent("""
def f(x):
print (x)
return x + 1
"""))
self.assertEqual('f', mod.body[0].name)
def test_parse_entity_print_function(self):
def f(x):
print(x)
node, _, _ = parser.parse_entity(
f, future_imports=['print_function'])
self.assertEqual('f', node.name)
def test_parse_comments(self):
def f():
# unindented comment
pass
with self.assertRaises(ValueError):
parser.parse_entity(f)
parser.parse_entity(f, future_imports=())
def test_parse_multiline_strings(self):
def f():
@ -74,7 +53,7 @@ some
multiline
string""")
with self.assertRaises(ValueError):
parser.parse_entity(f)
parser.parse_entity(f, future_imports=())
def test_parse_expression(self):
node = parser.parse_expression('a.b')

View File

@ -112,7 +112,7 @@ class ScopeTest(test.TestCase):
class ActivityAnalyzerTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,

View File

@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase):
literals=None,
arg_types=None):
literals = literals or {}
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class LivenessTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class DefinitionInfoTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,

View File

@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase):
test_fn,
namespace,
arg_types=None):
node, source, _ = parser.parse_entity(test_fn)
node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,

View File

@ -68,7 +68,7 @@ class TransformerTest(test.TestCase):
return b, inner_function
return a, TestClass
node, _, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node)
test_function_node = node
@ -141,7 +141,7 @@ class TransformerTest(test.TestCase):
while True:
raise '1'
node, _, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node)
fn_body = node.body
@ -207,7 +207,7 @@ class TransformerTest(test.TestCase):
raise '1'
return 'nor this'
node, _, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node)
for_node = node.body[2]
@ -238,7 +238,7 @@ class TransformerTest(test.TestCase):
print(a)
return None
node, _, _ = parser.parse_entity(no_exit)
node, _, _ = parser.parse_entity(no_exit, future_imports=())
with self.assertRaises(AssertionError):
tr.visit(node)
@ -246,7 +246,7 @@ class TransformerTest(test.TestCase):
for _ in a:
print(a)
node, _, _ = parser.parse_entity(no_entry)
node, _, _ = parser.parse_entity(no_entry, future_imports=())
with self.assertRaises(AssertionError):
tr.visit(node)
@ -272,7 +272,7 @@ class TransformerTest(test.TestCase):
tr = TestTransformer(self._simple_context())
node, _, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node)
self.assertEqual(len(node.body), 2)
@ -302,9 +302,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context())
_, _, all_nodes = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
with self.assertRaises(ValueError) as cm:
all_nodes = tr.visit(all_nodes)
node = tr.visit(node)
obtained_message = str(cm.exception)
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
self.assertRegexpMatches(obtained_message, expected_message)
@ -333,9 +333,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context())
_, _, all_nodes = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function, future_imports=())
with self.assertRaises(ValueError) as cm:
all_nodes = tr.visit(all_nodes)
node = tr.visit(node)
obtained_message = str(cm.exception)
# The message should reference the exception actually raised, not anything
# from the exception handler.