Preserve __future__ imports throughout the conversion process.

PiperOrigin-RevId: 239095963
This commit is contained in:
Brian Lee 2019-03-18 17:29:28 -07:00 committed by TensorFlower Gardener
parent b7a1100569
commit c79154ef49
16 changed files with 92 additions and 91 deletions

View File

@ -84,7 +84,7 @@ class DirectivesTest(converter_testing.TestCase):
def call_invalid_directive(): def call_invalid_directive():
invalid_directive(1) invalid_directive(1)
node, _, _ = parser.parse_entity(call_invalid_directive) node, _, _ = parser.parse_entity(call_invalid_directive, future_imports=())
# Find the call to the invalid directive # Find the call to the invalid directive
node = node.body[0].value node = node.body[0].value
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):

View File

@ -122,7 +122,7 @@ class TestCase(test.TestCase):
def prepare(self, test_fn, namespace, arg_types=None, recursive=True): def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
namespace['ConversionOptions'] = converter.ConversionOptions namespace['ConversionOptions'] = converter.ConversionOptions
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
namer = naming.Namer(namespace) namer = naming.Namer(namespace)
program_ctx = converter.ProgramContext( program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(recursive=recursive), options=converter.ConversionOptions(recursive=recursive),

View File

@ -235,6 +235,21 @@ def class_to_graph(c, program_ctx):
if not members: if not members:
raise ValueError('Cannot convert %s: it has no member methods.' % c) raise ValueError('Cannot convert %s: it has no member methods.' % c)
# TODO(mdan): Don't clobber namespaces for each method in one class namespace.
# The assumption that one namespace suffices for all methods only holds if
# all methods were defined in the same module.
# If, instead, functions are imported from multiple modules and then spliced
# into the class, then each function has its own globals and __future__
# imports that need to stay separate.
# For example, C's methods could both have `global x` statements referring to
# mod1.x and mod2.x, but using one namespace for C would cause a conflict.
# from mod1 import f1
# from mod2 import f2
# class C(object):
# method1 = f1
# method2 = f2
class_namespace = {} class_namespace = {}
for _, m in members: for _, m in members:
# Only convert the members that are directly defined by the class. # Only convert the members that are directly defined by the class.
@ -250,7 +265,13 @@ def class_to_graph(c, program_ctx):
class_namespace = namespace class_namespace = namespace
else: else:
class_namespace.update(namespace) class_namespace.update(namespace)
converted_members[m] = nodes[0] # TODO(brianklee): function_to_graph returns future import nodes and the
# converted function nodes. We discard all the future import nodes here
# which is buggy behavior, but really, the whole approach of gathering all
# of the converted function nodes in one place is intrinsically buggy.
# So this is a reminder to properly handle the future import nodes when we
# redo our approach to class conversion.
converted_members[m] = nodes[-1]
namer = naming.Namer(class_namespace) namer = naming.Namer(class_namespace)
class_name = namer.class_name(c.__name__) class_name = namer.class_name(c.__name__)
@ -335,8 +356,11 @@ def _add_self_references(namespace, autograph_module):
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True): def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
"""Specialization of `entity_to_graph` for callable functions.""" """Specialization of `entity_to_graph` for callable functions."""
node, source, _ = parser.parse_entity(f) future_imports = inspect_utils.getfutureimports(f)
node, future_import_nodes, source = parser.parse_entity(
f, future_imports=future_imports)
logging.log(3, 'Source code of %s:\n\n%s\n', f, source) logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
# Parsed AST should contain future imports and one function def node.
# In general, the output of inspect.getsource is inexact for lambdas because # In general, the output of inspect.getsource is inexact for lambdas because
# it uses regex matching to adjust the exact location around the line number # it uses regex matching to adjust the exact location around the line number
@ -386,7 +410,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
new_name = f.__name__ new_name = f.__name__
assert node.name == new_name assert node.name == new_name
return [node], new_name, namespace return future_import_nodes + [node], new_name, namespace
def node_to_graph(node, context): def node_to_graph(node, context):

View File

@ -58,7 +58,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef) self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name) self.assertEqual('tf__f', name)
self.assertIs(ns['b'], b) self.assertIs(ns['b'], b)
@ -71,7 +71,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef) self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name) self.assertEqual('tf__f', name)
self.assertEqual( self.assertEqual(
@ -87,7 +87,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None) nodes, _, _ = conversion.entity_to_graph(f, program_ctx, None, None)
f_node = nodes[0] f_node = nodes[-2]
self.assertEqual('tf__f', f_node.name) self.assertEqual('tf__f', f_node.name)
def test_entity_to_graph_class_hierarchy(self): def test_entity_to_graph_class_hierarchy(self):
@ -144,7 +144,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda) self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name) self.assertEqual('tf__lambda', name)
@ -156,7 +156,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda) self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name) self.assertEqual('tf__lambda', name)
@ -179,7 +179,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, _ = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.Assign) self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda) self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name) self.assertEqual('tf__lambda', name)
@ -194,7 +194,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx() program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
fn_node, _ = nodes fn_node = nodes[-2]
self.assertIsInstance(fn_node, gast.FunctionDef) self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual(fn_node.name, 'tf__f') self.assertEqual(fn_node.name, 'tf__f')
self.assertEqual('tf__f', name) self.assertEqual('tf__f', name)

View File

@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor):
class GraphVisitorTest(test.TestCase): class GraphVisitorTest(test.TestCase):
def _build_cfg(self, fn): def _build_cfg(self, fn):
node, _, _ = parser.parse_entity(fn) node, _, _ = parser.parse_entity(fn, future_imports=())
cfgs = cfg.build(node) cfgs = cfg.build(node)
return cfgs, node return cfgs, node
@ -91,7 +91,7 @@ class GraphVisitorTest(test.TestCase):
class AstToCfgTest(test.TestCase): class AstToCfgTest(test.TestCase):
def _build_cfg(self, fn): def _build_cfg(self, fn):
node, _, _ = parser.parse_entity(fn) node, _, _ = parser.parse_entity(fn, future_imports=())
cfgs = cfg.build(node) cfgs = cfg.build(node)
return cfgs return cfgs

View File

@ -81,7 +81,7 @@ class AnfTransformerTest(test.TestCase):
def test_function(): def test_function():
a = 0 a = 0
return a return a
node, _, _ = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
node = anf.transform(node, self._simple_context()) node = anf.transform(node, self._simple_context())
result, _ = compiler.ast_to_object(node) result, _ = compiler.ast_to_object(node)
self.assertEqual(test_function(), result.test_function()) self.assertEqual(test_function(), result.test_function())
@ -97,8 +97,8 @@ class AnfTransformerTest(test.TestCase):
# Testing the code bodies only. Wrapping them in functions so the # Testing the code bodies only. Wrapping them in functions so the
# syntax highlights nicely, but Python doesn't try to execute the # syntax highlights nicely, but Python doesn't try to execute the
# statements. # statements.
exp_node, _, _ = parser.parse_entity(expected_fn) exp_node, _, _ = parser.parse_entity(expected_fn, future_imports=())
node, _, _ = parser.parse_entity(test_fn) node, _, _ = parser.parse_entity(test_fn, future_imports=())
node = anf.transform( node = anf.transform(
node, self._simple_context(), gensym_source=DummyGensym) node, self._simple_context(), gensym_source=DummyGensym)
exp_name = exp_node.name exp_name = exp_node.name

View File

@ -39,12 +39,12 @@ class CompilerTest(test.TestCase):
b = x + 1 b = x + 1
return b return b
_, _, all_nodes = parser.parse_entity(test_fn) node, _, _ = parser.parse_entity(test_fn, future_imports=())
self.assertEqual( self.assertEqual(
textwrap.dedent(tf_inspect.getsource(test_fn)), textwrap.dedent(tf_inspect.getsource(test_fn)),
tf_inspect.getsource( tf_inspect.getsource(
compiler.ast_to_object(all_nodes)[0].test_fn)) compiler.ast_to_object([node])[0].test_fn))
def test_ast_to_source(self): def test_ast_to_source(self):
node = gast.If( node = gast.If(

View File

@ -32,7 +32,7 @@ class OriginInfoTest(test.TestCase):
def test_fn(x): def test_fn(x):
return x + 1 return x + 1
node, _, _ = parser.parse_entity(test_fn) node, _, _ = parser.parse_entity(test_fn, future_imports=())
fake_origin = origin_info.OriginInfo( fake_origin = origin_info.OriginInfo(
loc=origin_info.Location('fake_filename', 3, 7), loc=origin_info.Location('fake_filename', 3, 7),
function_name='fake_function_name', function_name='fake_function_name',
@ -53,7 +53,7 @@ class OriginInfoTest(test.TestCase):
def test_fn(x): def test_fn(x):
return x + 1 return x + 1
node, _, _ = parser.parse_entity(test_fn) node, _, _ = parser.parse_entity(test_fn, future_imports=())
converted_code = compiler.ast_to_source(node) converted_code = compiler.ast_to_source(node)
source_map = origin_info.create_source_map( source_map = origin_info.create_source_map(
@ -67,7 +67,7 @@ class OriginInfoTest(test.TestCase):
"""Docstring.""" """Docstring."""
return x # comment return x # comment
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
origin_info.resolve(node, source) origin_info.resolve(node, source)
@ -89,14 +89,15 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(origin.source_code_line, ' return x # comment') self.assertEqual(origin.source_code_line, ' return x # comment')
self.assertEqual(origin.comment, 'comment') self.assertEqual(origin.comment, 'comment')
def disabled_test_resolve_with_future_imports(self): def test_resolve_with_future_imports(self):
def test_fn(x): def test_fn(x):
"""Docstring.""" """Docstring."""
print(x) print(x)
return x # comment return x # comment
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(
test_fn, future_imports=['print_function'])
origin_info.resolve(node, source) origin_info.resolve(node, source)

View File

@ -21,12 +21,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import re import itertools
import textwrap import textwrap
import threading import threading
import gast import gast
import six
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
@ -34,17 +33,18 @@ from tensorflow.python.util import tf_inspect
_parse_lock = threading.Lock() # Prevents linecache concurrency errors. _parse_lock = threading.Lock() # Prevents linecache concurrency errors.
def parse_entity(entity): def parse_entity(entity, future_imports):
"""Returns the AST and source code of given entity. """Returns the AST and source code of given entity.
Args: Args:
entity: A python function/method/class entity: A python function/method/class
future_imports: An iterable of future imports to use when parsing AST. (e.g.
('print_statement', 'division', 'unicode_literals'))
Returns: Returns:
gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding gast.AST, List[gast.AST], str: a tuple of the AST node corresponding
exactly to the entity; the string that was parsed to generate the AST; and exactly to the entity; a list of future import AST nodes, and the string
the containing module AST node, which might contain extras like future that was parsed to generate the AST.
import nodes.
""" """
try: try:
with _parse_lock: with _parse_lock:
@ -67,11 +67,13 @@ def parse_entity(entity):
# causing textwrap.dedent to not correctly dedent source code. # causing textwrap.dedent to not correctly dedent source code.
# TODO(b/115884650): Automatic handling of comments/multiline strings. # TODO(b/115884650): Automatic handling of comments/multiline strings.
source = textwrap.dedent(source) source = textwrap.dedent(source)
future_import_strings = ('from __future__ import {}'.format(name)
for name in future_imports)
source = '\n'.join(itertools.chain(future_import_strings, [source]))
try: try:
module_node = parse_str(source) module_node = parse_str(source)
assert len(module_node.body) == 1 return _select_entity_node(module_node, source, future_imports)
return module_node.body[0], source, module_node
except IndentationError: except IndentationError:
# The text below lists the causes of this error known to us. There may # The text below lists the causes of this error known to us. There may
@ -112,7 +114,7 @@ def parse_entity(entity):
try: try:
module_node = parse_str(new_source) module_node = parse_str(new_source)
return module_node.body[0], new_source, module_node return _select_entity_node(module_node, new_source, future_imports)
except SyntaxError as e: except SyntaxError as e:
raise_parse_failure( raise_parse_failure(
'If this is a lambda function, the error may be avoided by creating' 'If this is a lambda function, the error may be avoided by creating'
@ -122,18 +124,7 @@ def parse_entity(entity):
def parse_str(src): def parse_str(src):
"""Returns the AST of given piece of code.""" """Returns the AST of given piece of code."""
# TODO(mdan): This should exclude the module things are autowrapped in. return gast.parse(src)
if six.PY2 and re.search('\\Wprint\\s*\\(', src):
# This special treatment is required because gast.parse is not aware of
# whether print_function was present in the original context.
src = 'from __future__ import print_function\n' + src
parsed_module = gast.parse(src)
parsed_module.body = parsed_module.body[1:]
else:
parsed_module = gast.parse(src)
return parsed_module
def parse_expression(src): def parse_expression(src):
@ -152,3 +143,9 @@ def parse_expression(src):
raise ValueError( raise ValueError(
'Expected a single expression, found instead %s' % node.body) 'Expected a single expression, found instead %s' % node.body)
return node.body[0].value return node.body[0].value
def _select_entity_node(module_node, source, future_imports):
assert len(module_node.body) == 1 + len(future_imports)
return module_node.body[-1], module_node.body[:-1], source

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import textwrap
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -31,41 +29,22 @@ class ParserTest(test.TestCase):
def f(x): def f(x):
return x + 1 return x + 1
node, _, _ = parser.parse_entity(f) node, _, _ = parser.parse_entity(f, future_imports=())
self.assertEqual('f', node.name) self.assertEqual('f', node.name)
def test_parse_str(self): def test_parse_entity_print_function(self):
mod = parser.parse_str( def f(x):
textwrap.dedent(""" print(x)
def f(x): node, _, _ = parser.parse_entity(
return x + 1 f, future_imports=['print_function'])
""")) self.assertEqual('f', node.name)
self.assertEqual('f', mod.body[0].name)
def test_parse_str_print(self):
mod = parser.parse_str(
textwrap.dedent("""
def f(x):
print(x)
return x + 1
"""))
self.assertEqual('f', mod.body[0].name)
def test_parse_str_weird_print(self):
mod = parser.parse_str(
textwrap.dedent("""
def f(x):
print (x)
return x + 1
"""))
self.assertEqual('f', mod.body[0].name)
def test_parse_comments(self): def test_parse_comments(self):
def f(): def f():
# unindented comment # unindented comment
pass pass
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parser.parse_entity(f) parser.parse_entity(f, future_imports=())
def test_parse_multiline_strings(self): def test_parse_multiline_strings(self):
def f(): def f():
@ -74,7 +53,7 @@ some
multiline multiline
string""") string""")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parser.parse_entity(f) parser.parse_entity(f, future_imports=())
def test_parse_expression(self): def test_parse_expression(self):
node = parser.parse_expression('a.b') node = parser.parse_expression('a.b')

View File

@ -112,7 +112,7 @@ class ScopeTest(test.TestCase):
class ActivityAnalyzerTest(test.TestCase): class ActivityAnalyzerTest(test.TestCase):
def _parse_and_analyze(self, test_fn): def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo( entity_info = transformer.EntityInfo(
source_code=source, source_code=source,
source_file=None, source_file=None,

View File

@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase):
literals=None, literals=None,
arg_types=None): arg_types=None):
literals = literals or {} literals = literals or {}
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo( entity_info = transformer.EntityInfo(
source_code=source, source_code=source,
source_file=None, source_file=None,

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class LivenessTest(test.TestCase): class LivenessTest(test.TestCase):
def _parse_and_analyze(self, test_fn): def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo( entity_info = transformer.EntityInfo(
source_code=source, source_code=source,
source_file=None, source_file=None,

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class DefinitionInfoTest(test.TestCase): class DefinitionInfoTest(test.TestCase):
def _parse_and_analyze(self, test_fn): def _parse_and_analyze(self, test_fn):
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo( entity_info = transformer.EntityInfo(
source_code=source, source_code=source,
source_file=None, source_file=None,

View File

@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase):
test_fn, test_fn,
namespace, namespace,
arg_types=None): arg_types=None):
node, source, _ = parser.parse_entity(test_fn) node, _, source = parser.parse_entity(test_fn, future_imports=())
entity_info = transformer.EntityInfo( entity_info = transformer.EntityInfo(
source_code=source, source_code=source,
source_file=None, source_file=None,

View File

@ -68,7 +68,7 @@ class TransformerTest(test.TestCase):
return b, inner_function return b, inner_function
return a, TestClass return a, TestClass
node, _, _ = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node) node = tr.visit(node)
test_function_node = node test_function_node = node
@ -141,7 +141,7 @@ class TransformerTest(test.TestCase):
while True: while True:
raise '1' raise '1'
node, _, _ = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node) node = tr.visit(node)
fn_body = node.body fn_body = node.body
@ -207,7 +207,7 @@ class TransformerTest(test.TestCase):
raise '1' raise '1'
return 'nor this' return 'nor this'
node, _, _ = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node) node = tr.visit(node)
for_node = node.body[2] for_node = node.body[2]
@ -238,7 +238,7 @@ class TransformerTest(test.TestCase):
print(a) print(a)
return None return None
node, _, _ = parser.parse_entity(no_exit) node, _, _ = parser.parse_entity(no_exit, future_imports=())
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
tr.visit(node) tr.visit(node)
@ -246,7 +246,7 @@ class TransformerTest(test.TestCase):
for _ in a: for _ in a:
print(a) print(a)
node, _, _ = parser.parse_entity(no_entry) node, _, _ = parser.parse_entity(no_entry, future_imports=())
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
tr.visit(node) tr.visit(node)
@ -272,7 +272,7 @@ class TransformerTest(test.TestCase):
tr = TestTransformer(self._simple_context()) tr = TestTransformer(self._simple_context())
node, _, _ = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
node = tr.visit(node) node = tr.visit(node)
self.assertEqual(len(node.body), 2) self.assertEqual(len(node.body), 2)
@ -302,9 +302,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context()) tr = BrokenTransformer(self._simple_context())
_, _, all_nodes = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
all_nodes = tr.visit(all_nodes) node = tr.visit(node)
obtained_message = str(cm.exception) obtained_message = str(cm.exception)
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"' expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
self.assertRegexpMatches(obtained_message, expected_message) self.assertRegexpMatches(obtained_message, expected_message)
@ -333,9 +333,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context()) tr = BrokenTransformer(self._simple_context())
_, _, all_nodes = parser.parse_entity(test_function) node, _, _ = parser.parse_entity(test_function, future_imports=())
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
all_nodes = tr.visit(all_nodes) node = tr.visit(node)
obtained_message = str(cm.exception) obtained_message = str(cm.exception)
# The message should reference the exception actually raised, not anything # The message should reference the exception actually raised, not anything
# from the exception handler. # from the exception handler.