Refactor parser.parse_entity to separately return node and containing module node

PiperOrigin-RevId: 238480153
This commit is contained in:
Brian Lee 2019-03-14 11:13:01 -07:00 committed by TensorFlower Gardener
parent cf3c25b7fb
commit 852d3364e6
15 changed files with 146 additions and 141 deletions

View File

@ -84,9 +84,9 @@ class DirectivesTest(converter_testing.TestCase):
def call_invalid_directive():
invalid_directive(1)
node, _ = parser.parse_entity(call_invalid_directive)
node, _, _ = parser.parse_entity(call_invalid_directive)
# Find the call to the invalid directive
node = node.body[0].body[0].value
node = node.body[0].value
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
directives_converter._map_args(node, invalid_directive)

View File

@ -122,8 +122,7 @@ class TestCase(test.TestCase):
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
namespace['ConversionOptions'] = converter.ConversionOptions
node, source = parser.parse_entity(test_fn)
node = node.body[0]
node, source, _ = parser.parse_entity(test_fn)
namer = naming.Namer(namespace)
program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(recursive=recursive),

View File

@ -335,9 +335,8 @@ def _add_self_references(namespace, autograph_module):
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
"""Specialization of `entity_to_graph` for callable functions."""
node, source = parser.parse_entity(f)
node, source, _ = parser.parse_entity(f)
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
node = node.body[0]
# In general, the output of inspect.getsource is inexact for lambdas because
# it uses regex matching to adjust the exact location around the line number

View File

@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor):
class GraphVisitorTest(test.TestCase):
def _build_cfg(self, fn):
node, _ = parser.parse_entity(fn)
node, _, _ = parser.parse_entity(fn)
cfgs = cfg.build(node)
return cfgs, node
@ -57,15 +57,14 @@ class GraphVisitorTest(test.TestCase):
graph, = graphs.values()
visitor = CountingVisitor(graph)
visitor.visit_forward()
fn_node = node.body[0]
self.assertEqual(visitor.counts[fn_node.args], 1)
self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
self.assertEqual(visitor.counts[node.args], 1)
self.assertEqual(visitor.counts[node.body[0].test], 1)
self.assertEqual(visitor.counts[node.body[0].body[0]], 1)
self.assertEqual(visitor.counts[node.body[0].body[1]], 1)
# The return node should be unreachable in forward direction.
self.assertTrue(fn_node.body[0].body[2] not in visitor.counts)
self.assertEqual(visitor.counts[fn_node.body[1]], 1)
self.assertNotIn(node.body[0].body[2], visitor.counts)
self.assertEqual(visitor.counts[node.body[1]], 1)
def test_basic_coverage_reverse(self):
@ -80,20 +79,19 @@ class GraphVisitorTest(test.TestCase):
graph, = graphs.values()
visitor = CountingVisitor(graph)
visitor.visit_reverse()
fn_node = node.body[0]
self.assertEqual(visitor.counts[fn_node.args], 1)
self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1)
self.assertEqual(visitor.counts[fn_node.body[1]], 1)
self.assertEqual(visitor.counts[node.args], 1)
self.assertEqual(visitor.counts[node.body[0].test], 1)
self.assertEqual(visitor.counts[node.body[0].body[0]], 1)
self.assertEqual(visitor.counts[node.body[0].body[1]], 1)
self.assertTrue(visitor.counts[node.body[0].body[2]], 1)
self.assertEqual(visitor.counts[node.body[1]], 1)
class AstToCfgTest(test.TestCase):
def _build_cfg(self, fn):
node, _ = parser.parse_entity(fn)
node, _, _ = parser.parse_entity(fn)
cfgs = cfg.build(node)
return cfgs

View File

@ -81,8 +81,8 @@ class AnfTransformerTest(test.TestCase):
def test_function():
a = 0
return a
node, _ = parser.parse_entity(test_function)
node = anf.transform(node.body[0], self._simple_context())
node, _, _ = parser.parse_entity(test_function)
node = anf.transform(node, self._simple_context())
result, _ = compiler.ast_to_object(node)
self.assertEqual(test_function(), result.test_function())
@ -97,15 +97,15 @@ class AnfTransformerTest(test.TestCase):
# Testing the code bodies only. Wrapping them in functions so the
# syntax highlights nicely, but Python doesn't try to execute the
# statements.
exp_node, _ = parser.parse_entity(expected_fn)
node, _ = parser.parse_entity(test_fn)
exp_node, _, _ = parser.parse_entity(expected_fn)
node, _, _ = parser.parse_entity(test_fn)
node = anf.transform(
node, self._simple_context(), gensym_source=DummyGensym)
exp_name = exp_node.body[0].name
exp_name = exp_node.name
# Ignoring the function names in the result because they can't be
# the same (because both functions have to exist in the same scope
# at the same time).
node.body[0].name = exp_name
node.name = exp_name
self.assert_same_ast(exp_node, node)
# Check that ANF is idempotent
node_repeated = anf.transform(

View File

@ -39,11 +39,12 @@ class CompilerTest(test.TestCase):
b = x + 1
return b
_, _, all_nodes = parser.parse_entity(test_fn)
self.assertEqual(
textwrap.dedent(tf_inspect.getsource(test_fn)),
tf_inspect.getsource(
compiler.ast_to_object(
parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
compiler.ast_to_object(all_nodes)[0].test_fn))
def test_ast_to_source(self):
node = gast.If(

View File

@ -32,18 +32,17 @@ class OriginInfoTest(test.TestCase):
def test_fn(x):
return x + 1
node, _ = parser.parse_entity(test_fn)
node, _, _ = parser.parse_entity(test_fn)
fake_origin = origin_info.OriginInfo(
loc=origin_info.Location('fake_filename', 3, 7),
function_name='fake_function_name',
source_code_line='fake source line',
comment=None)
fn_node = node.body[-1]
anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
converted_code = compiler.ast_to_source(fn_node)
anno.setanno(node.body[0], anno.Basic.ORIGIN, fake_origin)
converted_code = compiler.ast_to_source(node)
source_map = origin_info.create_source_map(
fn_node, converted_code, 'test_filename', [0])
node, converted_code, 'test_filename', [0])
loc = origin_info.LineLocation('test_filename', 2)
self.assertIn(loc, source_map)
@ -54,12 +53,11 @@ class OriginInfoTest(test.TestCase):
def test_fn(x):
return x + 1
node, _ = parser.parse_entity(test_fn)
fn_node = node.body[-1]
converted_code = compiler.ast_to_source(fn_node)
node, _, _ = parser.parse_entity(test_fn)
converted_code = compiler.ast_to_source(node)
source_map = origin_info.create_source_map(
fn_node, converted_code, 'test_filename', [0])
node, converted_code, 'test_filename', [0])
self.assertEqual(len(source_map), 0)
@ -69,24 +67,23 @@ class OriginInfoTest(test.TestCase):
"""Docstring."""
return x # comment
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
node, source, _ = parser.parse_entity(test_fn)
origin_info.resolve(fn_node, source)
origin_info.resolve(node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
origin = anno.getanno(node, anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 1)
self.assertEqual(origin.loc.col_offset, 0)
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 2)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' """Docstring."""')
self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
origin = anno.getanno(node.body[1], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 3)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' return x # comment')
@ -99,24 +96,23 @@ class OriginInfoTest(test.TestCase):
print(x)
return x # comment
node, source = parser.parse_entity(test_fn)
fn_node = node.body[-1]
node, source, _ = parser.parse_entity(test_fn)
origin_info.resolve(fn_node, source)
origin_info.resolve(node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
origin = anno.getanno(node, anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 2)
self.assertEqual(origin.loc.col_offset, 0)
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 3)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' """Docstring."""')
self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[2], anno.Basic.ORIGIN)
origin = anno.getanno(node.body[2], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 5)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' return x # comment')

View File

@ -35,7 +35,17 @@ _parse_lock = threading.Lock() # Prevents linecache concurrency errors.
def parse_entity(entity):
"""Returns the AST of given entity."""
"""Returns the AST and source code of given entity.
Args:
entity: A python function/method/class
Returns:
gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding
exactly to the entity; the string that was parsed to generate the AST; and
the containing module AST node, which might contain extras like future
import nodes.
"""
try:
with _parse_lock:
source = tf_inspect.getsource_no_unwrap(entity)
@ -59,7 +69,9 @@ def parse_entity(entity):
source = textwrap.dedent(source)
try:
return parse_str(source), source
module_node = parse_str(source)
assert len(module_node.body) == 1
return module_node.body[0], source, module_node
except IndentationError:
# The text below lists the causes of this error known to us. There may
@ -99,7 +111,8 @@ def parse_entity(entity):
new_source = '\n'.join(lines)
try:
return parse_str(new_source), new_source
module_node = parse_str(new_source)
return module_node.body[0], new_source, module_node
except SyntaxError as e:
raise_parse_failure(
'If this is a lambda function, the error may be avoided by creating'

View File

@ -31,8 +31,8 @@ class ParserTest(test.TestCase):
def f(x):
return x + 1
mod, _ = parser.parse_entity(f)
self.assertEqual('f', mod.body[0].name)
node, _, _ = parser.parse_entity(f)
self.assertEqual('f', node.name)
def test_parse_str(self):
mod = parser.parse_str(

View File

@ -112,7 +112,7 @@ class ScopeTest(test.TestCase):
class ActivityAnalyzerTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn)
node, source, _ = parser.parse_entity(test_fn)
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
@ -149,7 +149,7 @@ class ActivityAnalyzerTest(test.TestCase):
return c
node, _ = self._parse_and_analyze(test_fn)
print_node = node.body[0].body[2]
print_node = node.body[2]
if isinstance(print_node, gast.Print):
# Python 2
print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
@ -172,7 +172,7 @@ class ActivityAnalyzerTest(test.TestCase):
return c
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[2].value
call_node = node.body[2].value
# We basically need to detect which variables are captured by the call
# arguments.
self.assertScopeIs(
@ -189,7 +189,7 @@ class ActivityAnalyzerTest(test.TestCase):
return a.d
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[1].value
call_node = node.body[1].value
self.assertScopeIs(
anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), ())
@ -205,7 +205,7 @@ class ActivityAnalyzerTest(test.TestCase):
return a[c]
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[2].value
call_node = node.body[2].value
self.assertScopeIs(
anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
('a', 'a[0]', 'a[b]', 'b'), ())
@ -220,7 +220,7 @@ class ActivityAnalyzerTest(test.TestCase):
return b, c
node, _ = self._parse_and_analyze(test_fn)
while_node = node.body[0].body[1]
while_node = node.body[1]
self.assertScopeIs(
anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
self.assertScopeIs(
@ -239,7 +239,7 @@ class ActivityAnalyzerTest(test.TestCase):
return b, c
node, _ = self._parse_and_analyze(test_fn)
for_node = node.body[0].body[1]
for_node = node.body[1]
self.assertScopeIs(
anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
self.assertScopeIs(
@ -260,7 +260,7 @@ class ActivityAnalyzerTest(test.TestCase):
return z, u
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
if_node = node.body[0]
self.assertScopeIs(
anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'))
self.assertScopeIs(
@ -285,7 +285,7 @@ class ActivityAnalyzerTest(test.TestCase):
return d
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
if_node = node.body[0]
self.assertScopeIs(
anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'a.c'), ('a.b', 'd'))
self.assertScopeIs(
@ -307,7 +307,7 @@ class ActivityAnalyzerTest(test.TestCase):
return d
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
if_node = node.body[0]
self.assertScopeIs(
anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'),
('a[b]', 'd'))
@ -329,7 +329,7 @@ class ActivityAnalyzerTest(test.TestCase):
return a
node, _ = self._parse_and_analyze(test_fn)
inner_if_node = node.body[0].body[0].body[0]
inner_if_node = node.body[0].body[0]
self.assertScopeIs(
anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',))
self.assertScopeIs(
@ -350,7 +350,7 @@ class ActivityAnalyzerTest(test.TestCase):
return b, c
node, _ = self._parse_and_analyze(test_fn)
fn_def_node = node.body[0].body[0]
fn_def_node = node.body[0]
self.assertScopeIs(
anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',))
@ -364,7 +364,7 @@ class ActivityAnalyzerTest(test.TestCase):
self.b.c = 1
node, _ = self._parse_and_analyze(TestClass)
init_node = node.body[0].body[0]
init_node = node.body[0]
self.assertScopeIs(
anno.getanno(init_node, NodeAnno.BODY_SCOPE), ('self', 'a', 'self.b'),
('self', 'self.b', 'self.b.c'))
@ -375,7 +375,7 @@ class ActivityAnalyzerTest(test.TestCase):
a[0] += 1
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
self.assertScopeIs(
anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), ('a[0]',))
@ -385,7 +385,7 @@ class ActivityAnalyzerTest(test.TestCase):
return c
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
self.assertScopeIs(anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('c',), ())
def test_aug_assign(self):
@ -394,7 +394,7 @@ class ActivityAnalyzerTest(test.TestCase):
a += b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
self.assertScopeIs(
anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'b'), ('a'))
@ -409,7 +409,7 @@ class ActivityAnalyzerTest(test.TestCase):
foo()['bar'] += x
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
self.assertScopeIs(
anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ())
@ -419,7 +419,7 @@ class ActivityAnalyzerTest(test.TestCase):
return b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b',), ())
self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b'))
@ -433,7 +433,7 @@ class ActivityAnalyzerTest(test.TestCase):
return lambda: a + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('a', 'b'), ())
# Nothing local to the lambda is tracked.
@ -445,7 +445,7 @@ class ActivityAnalyzerTest(test.TestCase):
return lambda a: a + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b',), ())
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
@ -456,7 +456,7 @@ class ActivityAnalyzerTest(test.TestCase):
a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
@ -467,7 +467,7 @@ class ActivityAnalyzerTest(test.TestCase):
a = lambda a, b: d(lambda b: a + b + c) # pylint: disable=undefined-variable
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('c', 'd'), ('a',))
self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')

View File

@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase):
literals=None,
arg_types=None):
literals = literals or {}
node, source = parser.parse_entity(test_fn)
node, source, _ = parser.parse_entity(test_fn)
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
@ -67,7 +67,7 @@ class LiveValuesResolverTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'})
retval_node = node.body[0].body[0].value
retval_node = node.body[0].value
self.assertEquals('bar', anno.getanno(retval_node, 'live_val'))
def test_primitive_values(self):
@ -78,7 +78,7 @@ class LiveValuesResolverTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn, {'a': True})
retval_node = node.body[0].body[0].value
retval_node = node.body[0].value
if six.PY2:
self.assertEqual(
anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool'))
@ -94,7 +94,7 @@ class LiveValuesResolverTest(test.TestCase):
return foo()
node = self._parse_and_analyze(test_fn, {'foo': foo})
func_node = node.body[0].body[0].value.func
func_node = node.body[0].value.func
self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
@ -104,7 +104,7 @@ class LiveValuesResolverTest(test.TestCase):
return constant_op.constant(0)
node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
func_node = node.body[0].body[0].value.func
func_node = node.body[0].value.func
self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
self.assertEquals((constant_op.__name__, 'constant'),
anno.getanno(func_node, 'fqn'))
@ -122,7 +122,7 @@ class LiveValuesResolverTest(test.TestCase):
node = self._parse_and_analyze(
TestClass.test_fn, {'constant_op': constant_op},
arg_types={'self': (TestClass.__name__, TestClass)})
func_node = node.body[0].body[0].value.func
func_node = node.body[0].value.func
self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val'))
self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type'))
self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class LivenessTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn)
node, source, _ = parser.parse_entity(test_fn)
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
@ -75,7 +75,7 @@ class LivenessTest(test.TestCase):
return x
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], ('a', 'x'))
self.assertHasLiveOut(fn_body[1], 'x')
@ -92,7 +92,7 @@ class LivenessTest(test.TestCase):
return x
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], 'a')
self.assertHasLiveOut(fn_body[1], 'x')
@ -105,7 +105,7 @@ class LivenessTest(test.TestCase):
return x
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], 'x')
@ -117,7 +117,7 @@ class LivenessTest(test.TestCase):
return x.y
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], ('x.y', 'x'))
@ -133,7 +133,7 @@ class LivenessTest(test.TestCase):
foo()
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], 'a')
@ -151,7 +151,7 @@ class LivenessTest(test.TestCase):
child()
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], 'max')
@ -165,7 +165,7 @@ class LivenessTest(test.TestCase):
y = 0
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveOut(fn_body[0], ())
@ -179,7 +179,7 @@ class LivenessTest(test.TestCase):
return x
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x'))
self.assertHasLiveIn(fn_body[1], ('c', 'x'))
@ -196,7 +196,7 @@ class LivenessTest(test.TestCase):
return x
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd'))
self.assertHasLiveIn(fn_body[1], ('d', 'x'))
@ -211,7 +211,7 @@ class LivenessTest(test.TestCase):
return y, z
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
@ -226,7 +226,7 @@ class LivenessTest(test.TestCase):
return y, z
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z'))
@ -240,7 +240,7 @@ class LivenessTest(test.TestCase):
y = 0
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasLiveIn(fn_body[0], ('a', 'x', 'y'))
@ -251,7 +251,7 @@ class LivenessTest(test.TestCase):
return
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
if six.PY2:
self.assertHasLiveIn(fn_body[0], ('all', 'x', 'y'))
@ -265,7 +265,7 @@ class LivenessTest(test.TestCase):
return
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
if six.PY2:
self.assertHasLiveIn(fn_body[0], ('x', 'y'))
@ -279,7 +279,7 @@ class LivenessTest(test.TestCase):
return
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
if six.PY2:
self.assertHasLiveIn(fn_body[0], ('x', 'y'))
@ -293,7 +293,7 @@ class LivenessTest(test.TestCase):
return
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
if six.PY2:
self.assertHasLiveIn(fn_body[0], ('k', 'v', 'y'))

View File

@ -33,7 +33,7 @@ from tensorflow.python.platform import test
class DefinitionInfoTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn)
node, source, _ = parser.parse_entity(test_fn)
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
@ -86,7 +86,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasDefs(fn_body[0].targets[0], 1)
self.assertHasDefs(fn_body[1].test, 1)
@ -105,7 +105,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasDefs(fn_body[0].value.args[0], 1)
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
@ -128,7 +128,7 @@ class DefinitionInfoTest(test.TestCase):
return x, y
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasDefs(fn_body[0].targets[0], 1)
self.assertHasDefs(fn_body[1].test, 2)
@ -153,7 +153,7 @@ class DefinitionInfoTest(test.TestCase):
return x, y
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasDefs(fn_body[0].targets[0], 1)
self.assertHasDefs(fn_body[1].target, 1)
@ -178,7 +178,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
def_of_a_in_if = fn_body[1].body[0].targets[0]
self.assertHasDefs(fn_body[0].targets[0], 1)
@ -202,7 +202,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
parent_return = fn_body[3]
child_return = fn_body[1].body[1]
@ -219,7 +219,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0)
self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1)
@ -232,7 +232,7 @@ class DefinitionInfoTest(test.TestCase):
return l
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
creation = fn_body[0].targets[0]
mutation = fn_body[1].targets[0].value
@ -251,7 +251,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
first_def = fn_body[0].targets[0]
second_def = fn_body[1].orelse[0].targets[0]
@ -270,7 +270,7 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
use = fn_body[2].value
self.assertHasDefs(use, 0)
@ -285,9 +285,9 @@ class DefinitionInfoTest(test.TestCase):
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
param = node.body[0].args.args[0]
param = node.args.args[0]
source = fn_body[0].value.args[0]
target = fn_body[0].targets[0]
retval = fn_body[1].value
@ -302,7 +302,7 @@ class DefinitionInfoTest(test.TestCase):
return x # pylint:disable=undefined-variable
node = self._parse_and_analyze(test_fn)
fn_body = node.body[0].body
fn_body = node.body
listcomp_target = fn_body[0].value.args[0].generators[0].target
retval = fn_body[1].value

View File

@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase):
test_fn,
namespace,
arg_types=None):
node, source = parser.parse_entity(test_fn)
node, source, _ = parser.parse_entity(test_fn)
entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
@ -87,7 +87,7 @@ class TypeInfoResolverTest(test.TestCase):
return opt
node = self._parse_and_analyze(test_fn, {'training': training})
call_node = node.body[0].body[0].value
call_node = node.body[0].value
self.assertTrue(anno.getanno(call_node, 'is_constructor'))
self.assertEquals(training.GradientDescentOptimizer,
anno.getanno(call_node, 'type'))
@ -101,7 +101,7 @@ class TypeInfoResolverTest(test.TestCase):
return res
node = self._parse_and_analyze(test_fn, {})
call_node = node.body[0].body[0].value
call_node = node.body[0].value
self.assertFalse(anno.hasanno(call_node, 'is_constructor'))
def test_class_members_of_detected_constructor(self):
@ -111,7 +111,7 @@ class TypeInfoResolverTest(test.TestCase):
opt.minimize(0)
node = self._parse_and_analyze(test_fn, {'training': training})
method_call = node.body[0].body[1].value.func
method_call = node.body[1].value.func
self.assertEquals(training.GradientDescentOptimizer.minimize,
anno.getanno(method_call, 'live_val'))
@ -122,12 +122,12 @@ class TypeInfoResolverTest(test.TestCase):
sess.run(x)
node = self._parse_and_analyze(test_fn, {'session': session})
constructor_call = node.body[0].body[0].items[0].context_expr
constructor_call = node.body[0].items[0].context_expr
self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
self.assertEquals((session.__name__, 'Session'),
anno.getanno(constructor_call, 'type_fqn'))
method_call = node.body[0].body[0].body[0].value.func
method_call = node.body[0].body[0].value.func
self.assertEquals(session.Session.run, anno.getanno(method_call,
'live_val'))
@ -141,7 +141,7 @@ class TypeInfoResolverTest(test.TestCase):
opt.minimize(0)
node = self._parse_and_analyze(test_fn, {'training': training})
method_call = node.body[0].body[1].value.func
method_call = node.body[1].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_parameter_class_members(self):
@ -150,7 +150,7 @@ class TypeInfoResolverTest(test.TestCase):
opt.minimize(0)
node = self._parse_and_analyze(test_fn, {})
method_call = node.body[0].body[0].value.func
method_call = node.body[0].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_parameter_class_members_with_value_hints(self):
@ -165,7 +165,7 @@ class TypeInfoResolverTest(test.TestCase):
training.GradientDescentOptimizer)
})
method_call = node.body[0].body[0].value.func
method_call = node.body[0].value.func
self.assertEquals(training.GradientDescentOptimizer.minimize,
anno.getanno(method_call, 'live_val'))
@ -179,7 +179,7 @@ class TypeInfoResolverTest(test.TestCase):
foo()
node = self._parse_and_analyze(test_fn, {'bar': bar})
method_call = node.body[0].body[1].value.func
method_call = node.body[1].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_nested_members(self):
@ -189,7 +189,7 @@ class TypeInfoResolverTest(test.TestCase):
foo.bar.baz()
node = self._parse_and_analyze(test_fn, {'training': training})
method_call = node.body[0].body[1].value.func
method_call = node.body[1].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
def test_nested_unpacking(self):
@ -205,7 +205,7 @@ class TypeInfoResolverTest(test.TestCase):
return a, b, c
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
a, b, c = node.body[0].body[1].value.elts
a, b, c = node.body[1].value.elts
self.assertEquals(anno.getanno(a, 'type'), Foo)
self.assertEquals(anno.getanno(b, 'type'), Bar)
self.assertEquals(anno.getanno(c, 'type'), Foo)

View File

@ -68,10 +68,10 @@ class TransformerTest(test.TestCase):
return b, inner_function
return a, TestClass
node, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function)
node = tr.visit(node)
test_function_node = node.body[0]
test_function_node = node
test_class = test_function_node.body[1]
test_method = test_class.body[0]
inner_function = test_method.body[1]
@ -141,10 +141,10 @@ class TransformerTest(test.TestCase):
while True:
raise '1'
node, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function)
node = tr.visit(node)
fn_body = node.body[0].body
fn_body = node.body
outer_while_body = fn_body[1].body
self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')
@ -207,10 +207,10 @@ class TransformerTest(test.TestCase):
raise '1'
return 'nor this'
node, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function)
node = tr.visit(node)
for_node = node.body[0].body[2]
for_node = node.body[2]
while_node = for_node.body[1].orelse[1]
self.assertFalse(anno.hasanno(for_node, 'string'))
@ -238,7 +238,7 @@ class TransformerTest(test.TestCase):
print(a)
return None
node, _ = parser.parse_entity(no_exit)
node, _, _ = parser.parse_entity(no_exit)
with self.assertRaises(AssertionError):
tr.visit(node)
@ -246,7 +246,7 @@ class TransformerTest(test.TestCase):
for _ in a:
print(a)
node, _ = parser.parse_entity(no_entry)
node, _, _ = parser.parse_entity(no_entry)
with self.assertRaises(AssertionError):
tr.visit(node)
@ -272,9 +272,8 @@ class TransformerTest(test.TestCase):
tr = TestTransformer(self._simple_context())
node, _ = parser.parse_entity(test_function)
node, _, _ = parser.parse_entity(test_function)
node = tr.visit(node)
node = node.body[0]
self.assertEqual(len(node.body), 2)
self.assertTrue(isinstance(node.body[0], gast.Assign))
@ -303,9 +302,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context())
node, _ = parser.parse_entity(test_function)
_, _, all_nodes = parser.parse_entity(test_function)
with self.assertRaises(ValueError) as cm:
node = tr.visit(node)
all_nodes = tr.visit(all_nodes)
obtained_message = str(cm.exception)
expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
self.assertRegexpMatches(obtained_message, expected_message)
@ -334,9 +333,9 @@ class TransformerTest(test.TestCase):
tr = BrokenTransformer(self._simple_context())
node, _ = parser.parse_entity(test_function)
_, _, all_nodes = parser.parse_entity(test_function)
with self.assertRaises(ValueError) as cm:
node = tr.visit(node)
all_nodes = tr.visit(all_nodes)
obtained_message = str(cm.exception)
# The message should reference the exception actually raised, not anything
# from the exception handler.