Preserve __future__ imports throughout the conversion process.
PiperOrigin-RevId: 239095963
This commit is contained in:
parent
b7a1100569
commit
c79154ef49
@ -84,7 +84,7 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
def call_invalid_directive():
|
def call_invalid_directive():
|
||||||
invalid_directive(1)
|
invalid_directive(1)
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(call_invalid_directive)
|
node, _, _ = parser.parse_entity(call_invalid_directive, future_imports=())
|
||||||
# Find the call to the invalid directive
|
# Find the call to the invalid directive
|
||||||
node = node.body[0].value
|
node = node.body[0].value
|
||||||
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
|
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
|
||||||
|
@ -122,7 +122,7 @@ class TestCase(test.TestCase):
|
|||||||
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
|
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
|
||||||
namespace['ConversionOptions'] = converter.ConversionOptions
|
namespace['ConversionOptions'] = converter.ConversionOptions
|
||||||
|
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
namer = naming.Namer(namespace)
|
namer = naming.Namer(namespace)
|
||||||
program_ctx = converter.ProgramContext(
|
program_ctx = converter.ProgramContext(
|
||||||
options=converter.ConversionOptions(recursive=recursive),
|
options=converter.ConversionOptions(recursive=recursive),
|
||||||
|
@ -235,6 +235,21 @@ def class_to_graph(c, program_ctx):
|
|||||||
if not members:
|
if not members:
|
||||||
raise ValueError('Cannot convert %s: it has no member methods.' % c)
|
raise ValueError('Cannot convert %s: it has no member methods.' % c)
|
||||||
|
|
||||||
|
# TODO(mdan): Don't clobber namespaces for each method in one class namespace.
|
||||||
|
# The assumption that one namespace suffices for all methods only holds if
|
||||||
|
# all methods were defined in the same module.
|
||||||
|
# If, instead, functions are imported from multiple modules and then spliced
|
||||||
|
# into the class, then each function has its own globals and __future__
|
||||||
|
# imports that need to stay separate.
|
||||||
|
|
||||||
|
# For example, C's methods could both have `global x` statements referring to
|
||||||
|
# mod1.x and mod2.x, but using one namespace for C would cause a conflict.
|
||||||
|
# from mod1 import f1
|
||||||
|
# from mod2 import f2
|
||||||
|
# class C(object):
|
||||||
|
# method1 = f1
|
||||||
|
# method2 = f2
|
||||||
|
|
||||||
class_namespace = {}
|
class_namespace = {}
|
||||||
for _, m in members:
|
for _, m in members:
|
||||||
# Only convert the members that are directly defined by the class.
|
# Only convert the members that are directly defined by the class.
|
||||||
@ -250,7 +265,13 @@ def class_to_graph(c, program_ctx):
|
|||||||
class_namespace = namespace
|
class_namespace = namespace
|
||||||
else:
|
else:
|
||||||
class_namespace.update(namespace)
|
class_namespace.update(namespace)
|
||||||
converted_members[m] = nodes[0]
|
# TODO(brianklee): function_to_graph returns future import nodes and the
|
||||||
|
# converted function nodes. We discard all the future import nodes here
|
||||||
|
# which is buggy behavior, but really, the whole approach of gathering all
|
||||||
|
# of the converted function nodes in one place is intrinsically buggy.
|
||||||
|
# So this is a reminder to properly handle the future import nodes when we
|
||||||
|
# redo our approach to class conversion.
|
||||||
|
converted_members[m] = nodes[-1]
|
||||||
namer = naming.Namer(class_namespace)
|
namer = naming.Namer(class_namespace)
|
||||||
class_name = namer.class_name(c.__name__)
|
class_name = namer.class_name(c.__name__)
|
||||||
|
|
||||||
@ -335,8 +356,11 @@ def _add_self_references(namespace, autograph_module):
|
|||||||
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
|
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
|
||||||
"""Specialization of `entity_to_graph` for callable functions."""
|
"""Specialization of `entity_to_graph` for callable functions."""
|
||||||
|
|
||||||
node, source, _ = parser.parse_entity(f)
|
future_imports = inspect_utils.getfutureimports(f)
|
||||||
|
node, future_import_nodes, source = parser.parse_entity(
|
||||||
|
f, future_imports=future_imports)
|
||||||
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
|
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
|
||||||
|
# Parsed AST should contain future imports and one function def node.
|
||||||
|
|
||||||
# In general, the output of inspect.getsource is inexact for lambdas because
|
# In general, the output of inspect.getsource is inexact for lambdas because
|
||||||
# it uses regex matching to adjust the exact location around the line number
|
# it uses regex matching to adjust the exact location around the line number
|
||||||
@ -386,7 +410,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
|
|||||||
new_name = f.__name__
|
new_name = f.__name__
|
||||||
assert node.name == new_name
|
assert node.name == new_name
|
||||||
|
|
||||||
return [node], new_name, namespace
|
return future_import_nodes + [node], new_name, namespace
|
||||||
|
|
||||||
|
|
||||||
def node_to_graph(node, context):
|
def node_to_graph(node, context):
|
||||||
|
@ -58,7 +58,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||||
self.assertEqual('tf__f', name)
|
self.assertEqual('tf__f', name)
|
||||||
self.assertIs(ns['b'], b)
|
self.assertIs(ns['b'], b)
|
||||||
@ -71,7 +71,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||||
self.assertEqual('tf__f', name)
|
self.assertEqual('tf__f', name)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -87,7 +87,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
f_node = nodes[0]
|
f_node = nodes[-2]
|
||||||
self.assertEqual('tf__f', f_node.name)
|
self.assertEqual('tf__f', f_node.name)
|
||||||
|
|
||||||
def test_entity_to_graph_class_hierarchy(self):
|
def test_entity_to_graph_class_hierarchy(self):
|
||||||
@ -144,7 +144,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
self.assertIsInstance(fn_node, gast.Assign)
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||||
self.assertEqual('tf__lambda', name)
|
self.assertEqual('tf__lambda', name)
|
||||||
@ -156,7 +156,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
self.assertIsInstance(fn_node, gast.Assign)
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||||
self.assertEqual('tf__lambda', name)
|
self.assertEqual('tf__lambda', name)
|
||||||
@ -179,7 +179,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
self.assertIsInstance(fn_node, gast.Assign)
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||||
self.assertEqual('tf__lambda', name)
|
self.assertEqual('tf__lambda', name)
|
||||||
@ -194,7 +194,7 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
|
||||||
fn_node, _ = nodes
|
fn_node = nodes[-2]
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||||
self.assertEqual(fn_node.name, 'tf__f')
|
self.assertEqual(fn_node.name, 'tf__f')
|
||||||
self.assertEqual('tf__f', name)
|
self.assertEqual('tf__f', name)
|
||||||
|
@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor):
|
|||||||
class GraphVisitorTest(test.TestCase):
|
class GraphVisitorTest(test.TestCase):
|
||||||
|
|
||||||
def _build_cfg(self, fn):
|
def _build_cfg(self, fn):
|
||||||
node, _, _ = parser.parse_entity(fn)
|
node, _, _ = parser.parse_entity(fn, future_imports=())
|
||||||
cfgs = cfg.build(node)
|
cfgs = cfg.build(node)
|
||||||
return cfgs, node
|
return cfgs, node
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ class GraphVisitorTest(test.TestCase):
|
|||||||
class AstToCfgTest(test.TestCase):
|
class AstToCfgTest(test.TestCase):
|
||||||
|
|
||||||
def _build_cfg(self, fn):
|
def _build_cfg(self, fn):
|
||||||
node, _, _ = parser.parse_entity(fn)
|
node, _, _ = parser.parse_entity(fn, future_imports=())
|
||||||
cfgs = cfg.build(node)
|
cfgs = cfg.build(node)
|
||||||
return cfgs
|
return cfgs
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ class AnfTransformerTest(test.TestCase):
|
|||||||
def test_function():
|
def test_function():
|
||||||
a = 0
|
a = 0
|
||||||
return a
|
return a
|
||||||
node, _, _ = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
node = anf.transform(node, self._simple_context())
|
node = anf.transform(node, self._simple_context())
|
||||||
result, _ = compiler.ast_to_object(node)
|
result, _ = compiler.ast_to_object(node)
|
||||||
self.assertEqual(test_function(), result.test_function())
|
self.assertEqual(test_function(), result.test_function())
|
||||||
@ -97,8 +97,8 @@ class AnfTransformerTest(test.TestCase):
|
|||||||
# Testing the code bodies only. Wrapping them in functions so the
|
# Testing the code bodies only. Wrapping them in functions so the
|
||||||
# syntax highlights nicely, but Python doesn't try to execute the
|
# syntax highlights nicely, but Python doesn't try to execute the
|
||||||
# statements.
|
# statements.
|
||||||
exp_node, _, _ = parser.parse_entity(expected_fn)
|
exp_node, _, _ = parser.parse_entity(expected_fn, future_imports=())
|
||||||
node, _, _ = parser.parse_entity(test_fn)
|
node, _, _ = parser.parse_entity(test_fn, future_imports=())
|
||||||
node = anf.transform(
|
node = anf.transform(
|
||||||
node, self._simple_context(), gensym_source=DummyGensym)
|
node, self._simple_context(), gensym_source=DummyGensym)
|
||||||
exp_name = exp_node.name
|
exp_name = exp_node.name
|
||||||
|
@ -39,12 +39,12 @@ class CompilerTest(test.TestCase):
|
|||||||
b = x + 1
|
b = x + 1
|
||||||
return b
|
return b
|
||||||
|
|
||||||
_, _, all_nodes = parser.parse_entity(test_fn)
|
node, _, _ = parser.parse_entity(test_fn, future_imports=())
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
textwrap.dedent(tf_inspect.getsource(test_fn)),
|
textwrap.dedent(tf_inspect.getsource(test_fn)),
|
||||||
tf_inspect.getsource(
|
tf_inspect.getsource(
|
||||||
compiler.ast_to_object(all_nodes)[0].test_fn))
|
compiler.ast_to_object([node])[0].test_fn))
|
||||||
|
|
||||||
def test_ast_to_source(self):
|
def test_ast_to_source(self):
|
||||||
node = gast.If(
|
node = gast.If(
|
||||||
|
@ -32,7 +32,7 @@ class OriginInfoTest(test.TestCase):
|
|||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
return x + 1
|
return x + 1
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_fn)
|
node, _, _ = parser.parse_entity(test_fn, future_imports=())
|
||||||
fake_origin = origin_info.OriginInfo(
|
fake_origin = origin_info.OriginInfo(
|
||||||
loc=origin_info.Location('fake_filename', 3, 7),
|
loc=origin_info.Location('fake_filename', 3, 7),
|
||||||
function_name='fake_function_name',
|
function_name='fake_function_name',
|
||||||
@ -53,7 +53,7 @@ class OriginInfoTest(test.TestCase):
|
|||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
return x + 1
|
return x + 1
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_fn)
|
node, _, _ = parser.parse_entity(test_fn, future_imports=())
|
||||||
converted_code = compiler.ast_to_source(node)
|
converted_code = compiler.ast_to_source(node)
|
||||||
|
|
||||||
source_map = origin_info.create_source_map(
|
source_map = origin_info.create_source_map(
|
||||||
@ -67,7 +67,7 @@ class OriginInfoTest(test.TestCase):
|
|||||||
"""Docstring."""
|
"""Docstring."""
|
||||||
return x # comment
|
return x # comment
|
||||||
|
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
|
|
||||||
origin_info.resolve(node, source)
|
origin_info.resolve(node, source)
|
||||||
|
|
||||||
@ -89,14 +89,15 @@ class OriginInfoTest(test.TestCase):
|
|||||||
self.assertEqual(origin.source_code_line, ' return x # comment')
|
self.assertEqual(origin.source_code_line, ' return x # comment')
|
||||||
self.assertEqual(origin.comment, 'comment')
|
self.assertEqual(origin.comment, 'comment')
|
||||||
|
|
||||||
def disabled_test_resolve_with_future_imports(self):
|
def test_resolve_with_future_imports(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
"""Docstring."""
|
"""Docstring."""
|
||||||
print(x)
|
print(x)
|
||||||
return x # comment
|
return x # comment
|
||||||
|
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(
|
||||||
|
test_fn, future_imports=['print_function'])
|
||||||
|
|
||||||
origin_info.resolve(node, source)
|
origin_info.resolve(node, source)
|
||||||
|
|
||||||
|
@ -21,12 +21,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import re
|
import itertools
|
||||||
import textwrap
|
import textwrap
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import gast
|
import gast
|
||||||
import six
|
|
||||||
|
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
@ -34,17 +33,18 @@ from tensorflow.python.util import tf_inspect
|
|||||||
_parse_lock = threading.Lock() # Prevents linecache concurrency errors.
|
_parse_lock = threading.Lock() # Prevents linecache concurrency errors.
|
||||||
|
|
||||||
|
|
||||||
def parse_entity(entity):
|
def parse_entity(entity, future_imports):
|
||||||
"""Returns the AST and source code of given entity.
|
"""Returns the AST and source code of given entity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity: A python function/method/class
|
entity: A python function/method/class
|
||||||
|
future_imports: An iterable of future imports to use when parsing AST. (e.g.
|
||||||
|
('print_statement', 'division', 'unicode_literals'))
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding
|
gast.AST, List[gast.AST], str: a tuple of the AST node corresponding
|
||||||
exactly to the entity; the string that was parsed to generate the AST; and
|
exactly to the entity; a list of future import AST nodes, and the string
|
||||||
the containing module AST node, which might contain extras like future
|
that was parsed to generate the AST.
|
||||||
import nodes.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with _parse_lock:
|
with _parse_lock:
|
||||||
@ -67,11 +67,13 @@ def parse_entity(entity):
|
|||||||
# causing textwrap.dedent to not correctly dedent source code.
|
# causing textwrap.dedent to not correctly dedent source code.
|
||||||
# TODO(b/115884650): Automatic handling of comments/multiline strings.
|
# TODO(b/115884650): Automatic handling of comments/multiline strings.
|
||||||
source = textwrap.dedent(source)
|
source = textwrap.dedent(source)
|
||||||
|
future_import_strings = ('from __future__ import {}'.format(name)
|
||||||
|
for name in future_imports)
|
||||||
|
source = '\n'.join(itertools.chain(future_import_strings, [source]))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
module_node = parse_str(source)
|
module_node = parse_str(source)
|
||||||
assert len(module_node.body) == 1
|
return _select_entity_node(module_node, source, future_imports)
|
||||||
return module_node.body[0], source, module_node
|
|
||||||
|
|
||||||
except IndentationError:
|
except IndentationError:
|
||||||
# The text below lists the causes of this error known to us. There may
|
# The text below lists the causes of this error known to us. There may
|
||||||
@ -112,7 +114,7 @@ def parse_entity(entity):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
module_node = parse_str(new_source)
|
module_node = parse_str(new_source)
|
||||||
return module_node.body[0], new_source, module_node
|
return _select_entity_node(module_node, new_source, future_imports)
|
||||||
except SyntaxError as e:
|
except SyntaxError as e:
|
||||||
raise_parse_failure(
|
raise_parse_failure(
|
||||||
'If this is a lambda function, the error may be avoided by creating'
|
'If this is a lambda function, the error may be avoided by creating'
|
||||||
@ -122,18 +124,7 @@ def parse_entity(entity):
|
|||||||
|
|
||||||
def parse_str(src):
|
def parse_str(src):
|
||||||
"""Returns the AST of given piece of code."""
|
"""Returns the AST of given piece of code."""
|
||||||
# TODO(mdan): This should exclude the module things are autowrapped in.
|
return gast.parse(src)
|
||||||
|
|
||||||
if six.PY2 and re.search('\\Wprint\\s*\\(', src):
|
|
||||||
# This special treatment is required because gast.parse is not aware of
|
|
||||||
# whether print_function was present in the original context.
|
|
||||||
src = 'from __future__ import print_function\n' + src
|
|
||||||
parsed_module = gast.parse(src)
|
|
||||||
parsed_module.body = parsed_module.body[1:]
|
|
||||||
else:
|
|
||||||
parsed_module = gast.parse(src)
|
|
||||||
|
|
||||||
return parsed_module
|
|
||||||
|
|
||||||
|
|
||||||
def parse_expression(src):
|
def parse_expression(src):
|
||||||
@ -152,3 +143,9 @@ def parse_expression(src):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Expected a single expression, found instead %s' % node.body)
|
'Expected a single expression, found instead %s' % node.body)
|
||||||
return node.body[0].value
|
return node.body[0].value
|
||||||
|
|
||||||
|
|
||||||
|
def _select_entity_node(module_node, source, future_imports):
|
||||||
|
assert len(module_node.body) == 1 + len(future_imports)
|
||||||
|
return module_node.body[-1], module_node.body[:-1], source
|
||||||
|
|
||||||
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import textwrap
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -31,41 +29,22 @@ class ParserTest(test.TestCase):
|
|||||||
def f(x):
|
def f(x):
|
||||||
return x + 1
|
return x + 1
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(f)
|
node, _, _ = parser.parse_entity(f, future_imports=())
|
||||||
self.assertEqual('f', node.name)
|
self.assertEqual('f', node.name)
|
||||||
|
|
||||||
def test_parse_str(self):
|
def test_parse_entity_print_function(self):
|
||||||
mod = parser.parse_str(
|
def f(x):
|
||||||
textwrap.dedent("""
|
print(x)
|
||||||
def f(x):
|
node, _, _ = parser.parse_entity(
|
||||||
return x + 1
|
f, future_imports=['print_function'])
|
||||||
"""))
|
self.assertEqual('f', node.name)
|
||||||
self.assertEqual('f', mod.body[0].name)
|
|
||||||
|
|
||||||
def test_parse_str_print(self):
|
|
||||||
mod = parser.parse_str(
|
|
||||||
textwrap.dedent("""
|
|
||||||
def f(x):
|
|
||||||
print(x)
|
|
||||||
return x + 1
|
|
||||||
"""))
|
|
||||||
self.assertEqual('f', mod.body[0].name)
|
|
||||||
|
|
||||||
def test_parse_str_weird_print(self):
|
|
||||||
mod = parser.parse_str(
|
|
||||||
textwrap.dedent("""
|
|
||||||
def f(x):
|
|
||||||
print (x)
|
|
||||||
return x + 1
|
|
||||||
"""))
|
|
||||||
self.assertEqual('f', mod.body[0].name)
|
|
||||||
|
|
||||||
def test_parse_comments(self):
|
def test_parse_comments(self):
|
||||||
def f():
|
def f():
|
||||||
# unindented comment
|
# unindented comment
|
||||||
pass
|
pass
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
parser.parse_entity(f)
|
parser.parse_entity(f, future_imports=())
|
||||||
|
|
||||||
def test_parse_multiline_strings(self):
|
def test_parse_multiline_strings(self):
|
||||||
def f():
|
def f():
|
||||||
@ -74,7 +53,7 @@ some
|
|||||||
multiline
|
multiline
|
||||||
string""")
|
string""")
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
parser.parse_entity(f)
|
parser.parse_entity(f, future_imports=())
|
||||||
|
|
||||||
def test_parse_expression(self):
|
def test_parse_expression(self):
|
||||||
node = parser.parse_expression('a.b')
|
node = parser.parse_expression('a.b')
|
||||||
|
@ -112,7 +112,7 @@ class ScopeTest(test.TestCase):
|
|||||||
class ActivityAnalyzerTest(test.TestCase):
|
class ActivityAnalyzerTest(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
|
@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase):
|
|||||||
literals=None,
|
literals=None,
|
||||||
arg_types=None):
|
arg_types=None):
|
||||||
literals = literals or {}
|
literals = literals or {}
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.platform import test
|
|||||||
class LivenessTest(test.TestCase):
|
class LivenessTest(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
|
@ -33,7 +33,7 @@ from tensorflow.python.platform import test
|
|||||||
class DefinitionInfoTest(test.TestCase):
|
class DefinitionInfoTest(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
|
@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase):
|
|||||||
test_fn,
|
test_fn,
|
||||||
namespace,
|
namespace,
|
||||||
arg_types=None):
|
arg_types=None):
|
||||||
node, source, _ = parser.parse_entity(test_fn)
|
node, _, source = parser.parse_entity(test_fn, future_imports=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
|
@ -68,7 +68,7 @@ class TransformerTest(test.TestCase):
|
|||||||
return b, inner_function
|
return b, inner_function
|
||||||
return a, TestClass
|
return a, TestClass
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
node = tr.visit(node)
|
node = tr.visit(node)
|
||||||
|
|
||||||
test_function_node = node
|
test_function_node = node
|
||||||
@ -141,7 +141,7 @@ class TransformerTest(test.TestCase):
|
|||||||
while True:
|
while True:
|
||||||
raise '1'
|
raise '1'
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
node = tr.visit(node)
|
node = tr.visit(node)
|
||||||
|
|
||||||
fn_body = node.body
|
fn_body = node.body
|
||||||
@ -207,7 +207,7 @@ class TransformerTest(test.TestCase):
|
|||||||
raise '1'
|
raise '1'
|
||||||
return 'nor this'
|
return 'nor this'
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
node = tr.visit(node)
|
node = tr.visit(node)
|
||||||
|
|
||||||
for_node = node.body[2]
|
for_node = node.body[2]
|
||||||
@ -238,7 +238,7 @@ class TransformerTest(test.TestCase):
|
|||||||
print(a)
|
print(a)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(no_exit)
|
node, _, _ = parser.parse_entity(no_exit, future_imports=())
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
tr.visit(node)
|
tr.visit(node)
|
||||||
|
|
||||||
@ -246,7 +246,7 @@ class TransformerTest(test.TestCase):
|
|||||||
for _ in a:
|
for _ in a:
|
||||||
print(a)
|
print(a)
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(no_entry)
|
node, _, _ = parser.parse_entity(no_entry, future_imports=())
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
tr.visit(node)
|
tr.visit(node)
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ class TransformerTest(test.TestCase):
|
|||||||
|
|
||||||
tr = TestTransformer(self._simple_context())
|
tr = TestTransformer(self._simple_context())
|
||||||
|
|
||||||
node, _, _ = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
node = tr.visit(node)
|
node = tr.visit(node)
|
||||||
|
|
||||||
self.assertEqual(len(node.body), 2)
|
self.assertEqual(len(node.body), 2)
|
||||||
@ -302,9 +302,9 @@ class TransformerTest(test.TestCase):
|
|||||||
|
|
||||||
tr = BrokenTransformer(self._simple_context())
|
tr = BrokenTransformer(self._simple_context())
|
||||||
|
|
||||||
_, _, all_nodes = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
with self.assertRaises(ValueError) as cm:
|
with self.assertRaises(ValueError) as cm:
|
||||||
all_nodes = tr.visit(all_nodes)
|
node = tr.visit(node)
|
||||||
obtained_message = str(cm.exception)
|
obtained_message = str(cm.exception)
|
||||||
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
|
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
|
||||||
self.assertRegexpMatches(obtained_message, expected_message)
|
self.assertRegexpMatches(obtained_message, expected_message)
|
||||||
@ -333,9 +333,9 @@ class TransformerTest(test.TestCase):
|
|||||||
|
|
||||||
tr = BrokenTransformer(self._simple_context())
|
tr = BrokenTransformer(self._simple_context())
|
||||||
|
|
||||||
_, _, all_nodes = parser.parse_entity(test_function)
|
node, _, _ = parser.parse_entity(test_function, future_imports=())
|
||||||
with self.assertRaises(ValueError) as cm:
|
with self.assertRaises(ValueError) as cm:
|
||||||
all_nodes = tr.visit(all_nodes)
|
node = tr.visit(node)
|
||||||
obtained_message = str(cm.exception)
|
obtained_message = str(cm.exception)
|
||||||
# The message should reference the exception actually raised, not anything
|
# The message should reference the exception actually raised, not anything
|
||||||
# from the exception handler.
|
# from the exception handler.
|
||||||
|
Loading…
Reference in New Issue
Block a user