diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py index 570fb8e379b..870a491ccdf 100644 --- a/tensorflow/python/autograph/converters/directives_test.py +++ b/tensorflow/python/autograph/converters/directives_test.py @@ -84,9 +84,9 @@ class DirectivesTest(converter_testing.TestCase): def call_invalid_directive(): invalid_directive(1) - node, _ = parser.parse_entity(call_invalid_directive) + node, _, _ = parser.parse_entity(call_invalid_directive) # Find the call to the invalid directive - node = node.body[0].body[0].value + node = node.body[0].value with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): directives_converter._map_args(node, invalid_directive) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 81b4b9f366f..e2d95b89095 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -122,8 +122,7 @@ class TestCase(test.TestCase): def prepare(self, test_fn, namespace, arg_types=None, recursive=True): namespace['ConversionOptions'] = converter.ConversionOptions - node, source = parser.parse_entity(test_fn) - node = node.body[0] + node, source, _ = parser.parse_entity(test_fn) namer = naming.Namer(namespace) program_ctx = converter.ProgramContext( options=converter.ConversionOptions(recursive=recursive), diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index a9913ef2a68..bb9464c3361 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -335,9 +335,8 @@ def _add_self_references(namespace, autograph_module): def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True): """Specialization of `entity_to_graph` for callable functions.""" - node, source = parser.parse_entity(f) + node, source, _ = parser.parse_entity(f) logging.log(3, 'Source code of %s:\n\n%s\n', f, source) - node = node.body[0] # In general, the output of inspect.getsource is inexact for lambdas because # it uses regex matching to adjust the exact location around the line number diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py index d5870124bce..8fb66ca7a76 100644 --- a/tensorflow/python/autograph/pyct/cfg_test.py +++ b/tensorflow/python/autograph/pyct/cfg_test.py @@ -40,7 +40,7 @@ class CountingVisitor(cfg.GraphVisitor): class GraphVisitorTest(test.TestCase): def _build_cfg(self, fn): - node, _ = parser.parse_entity(fn) + node, _, _ = parser.parse_entity(fn) cfgs = cfg.build(node) return cfgs, node @@ -57,15 +57,14 @@ class GraphVisitorTest(test.TestCase): graph, = graphs.values() visitor = CountingVisitor(graph) visitor.visit_forward() - fn_node = node.body[0] - self.assertEqual(visitor.counts[fn_node.args], 1) - self.assertEqual(visitor.counts[fn_node.body[0].test], 1) - self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) - self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertEqual(visitor.counts[node.args], 1) + self.assertEqual(visitor.counts[node.body[0].test], 1) + self.assertEqual(visitor.counts[node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[node.body[0].body[1]], 1) # The return node should be unreachable in forward direction. - self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) - self.assertEqual(visitor.counts[fn_node.body[1]], 1) + self.assertNotIn(node.body[0].body[2], visitor.counts) + self.assertEqual(visitor.counts[node.body[1]], 1) def test_basic_coverage_reverse(self): @@ -80,20 +79,19 @@ class GraphVisitorTest(test.TestCase): graph, = graphs.values() visitor = CountingVisitor(graph) visitor.visit_reverse() - fn_node = node.body[0] - self.assertEqual(visitor.counts[fn_node.args], 1) - self.assertEqual(visitor.counts[fn_node.body[0].test], 1) - self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) - self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) - self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) - self.assertEqual(visitor.counts[fn_node.body[1]], 1) + self.assertEqual(visitor.counts[node.args], 1) + self.assertEqual(visitor.counts[node.body[0].test], 1) + self.assertEqual(visitor.counts[node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[node.body[1]], 1) class AstToCfgTest(test.TestCase): def _build_cfg(self, fn): - node, _ = parser.parse_entity(fn) + node, _, _ = parser.parse_entity(fn) cfgs = cfg.build(node) return cfgs diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index 5b3bc438570..d7750604778 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -81,8 +81,8 @@ class AnfTransformerTest(test.TestCase): def test_function(): a = 0 return a - node, _ = parser.parse_entity(test_function) - node = anf.transform(node.body[0], self._simple_context()) + node, _, _ = parser.parse_entity(test_function) + node = anf.transform(node, self._simple_context()) result, _ = compiler.ast_to_object(node) self.assertEqual(test_function(), result.test_function()) @@ -97,15 +97,15 @@ class AnfTransformerTest(test.TestCase): # Testing the code bodies only. Wrapping them in functions so the # syntax highlights nicely, but Python doesn't try to execute the # statements. - exp_node, _ = parser.parse_entity(expected_fn) - node, _ = parser.parse_entity(test_fn) + exp_node, _, _ = parser.parse_entity(expected_fn) + node, _, _ = parser.parse_entity(test_fn) node = anf.transform( node, self._simple_context(), gensym_source=DummyGensym) - exp_name = exp_node.body[0].name + exp_name = exp_node.name # Ignoring the function names in the result because they can't be # the same (because both functions have to exist in the same scope # at the same time). - node.body[0].name = exp_name + node.name = exp_name self.assert_same_ast(exp_node, node) # Check that ANF is idempotent node_repeated = anf.transform( diff --git a/tensorflow/python/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py index 6fa289d3cc3..29e8a198fe6 100644 --- a/tensorflow/python/autograph/pyct/compiler_test.py +++ b/tensorflow/python/autograph/pyct/compiler_test.py @@ -39,11 +39,12 @@ class CompilerTest(test.TestCase): b = x + 1 return b + _, _, all_nodes = parser.parse_entity(test_fn) + self.assertEqual( textwrap.dedent(tf_inspect.getsource(test_fn)), tf_inspect.getsource( - compiler.ast_to_object( - parser.parse_entity(test_fn)[0].body[0])[0].test_fn)) + compiler.ast_to_object(all_nodes)[0].test_fn)) def test_ast_to_source(self): node = gast.If( diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py index 4f25cf4996d..a3dc2f82716 100644 --- a/tensorflow/python/autograph/pyct/origin_info_test.py +++ b/tensorflow/python/autograph/pyct/origin_info_test.py @@ -32,18 +32,17 @@ class OriginInfoTest(test.TestCase): def test_fn(x): return x + 1 - node, _ = parser.parse_entity(test_fn) + node, _, _ = parser.parse_entity(test_fn) fake_origin = origin_info.OriginInfo( loc=origin_info.Location('fake_filename', 3, 7), function_name='fake_function_name', source_code_line='fake source line', comment=None) - fn_node = node.body[-1] - anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin) - converted_code = compiler.ast_to_source(fn_node) + anno.setanno(node.body[0], anno.Basic.ORIGIN, fake_origin) + converted_code = compiler.ast_to_source(node) source_map = origin_info.create_source_map( - fn_node, converted_code, 'test_filename', [0]) + node, converted_code, 'test_filename', [0]) loc = origin_info.LineLocation('test_filename', 2) self.assertIn(loc, source_map) @@ -54,12 +53,11 @@ class OriginInfoTest(test.TestCase): def test_fn(x): return x + 1 - node, _ = parser.parse_entity(test_fn) - fn_node = node.body[-1] - converted_code = compiler.ast_to_source(fn_node) + node, _, _ = parser.parse_entity(test_fn) + converted_code = compiler.ast_to_source(node) source_map = origin_info.create_source_map( - fn_node, converted_code, 'test_filename', [0]) + node, converted_code, 'test_filename', [0]) self.assertEqual(len(source_map), 0) @@ -69,24 +67,23 @@ class OriginInfoTest(test.TestCase): """Docstring.""" return x # comment - node, source = parser.parse_entity(test_fn) - fn_node = node.body[0] + node, source, _ = parser.parse_entity(test_fn) - origin_info.resolve(fn_node, source) + origin_info.resolve(node, source) - origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + origin = anno.getanno(node, anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 1) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') self.assertIsNone(origin.comment) - origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') self.assertIsNone(origin.comment) - origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) + origin = anno.getanno(node.body[1], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') @@ -99,24 +96,23 @@ class OriginInfoTest(test.TestCase): print(x) return x # comment - node, source = parser.parse_entity(test_fn) - fn_node = node.body[-1] + node, source, _ = parser.parse_entity(test_fn) - origin_info.resolve(fn_node, source) + origin_info.resolve(node, source) - origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + origin = anno.getanno(node, anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') self.assertIsNone(origin.comment) - origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') self.assertIsNone(origin.comment) - origin = anno.getanno(fn_node.body[2], anno.Basic.ORIGIN) + origin = anno.getanno(node.body[2], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 5) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py index f6b2a7863bd..d6f51741809 100644 --- a/tensorflow/python/autograph/pyct/parser.py +++ b/tensorflow/python/autograph/pyct/parser.py @@ -35,7 +35,17 @@ _parse_lock = threading.Lock() # Prevents linecache concurrency errors. def parse_entity(entity): - """Returns the AST of given entity.""" + """Returns the AST and source code of given entity. + + Args: + entity: A python function/method/class + + Returns: + gast.AST, str, gast.ModuleNode: a tuple of the AST node corresponding + exactly to the entity; the string that was parsed to generate the AST; and + the containing module AST node, which might contain extras like future + import nodes. + """ try: with _parse_lock: source = tf_inspect.getsource_no_unwrap(entity) @@ -59,7 +69,9 @@ def parse_entity(entity): source = textwrap.dedent(source) try: - return parse_str(source), source + module_node = parse_str(source) + assert len(module_node.body) == 1 + return module_node.body[0], source, module_node except IndentationError: # The text below lists the causes of this error known to us. There may @@ -99,7 +111,8 @@ def parse_entity(entity): new_source = '\n'.join(lines) try: - return parse_str(new_source), new_source + module_node = parse_str(new_source) + return module_node.body[0], new_source, module_node except SyntaxError as e: raise_parse_failure( 'If this is a lambda function, the error may be avoided by creating' diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py index e7fa3c7aeb5..ee3e2808259 100644 --- a/tensorflow/python/autograph/pyct/parser_test.py +++ b/tensorflow/python/autograph/pyct/parser_test.py @@ -31,8 +31,8 @@ class ParserTest(test.TestCase): def f(x): return x + 1 - mod, _ = parser.parse_entity(f) - self.assertEqual('f', mod.body[0].name) + node, _, _ = parser.parse_entity(f) + self.assertEqual('f', node.name) def test_parse_str(self): mod = parser.parse_str( diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index 0dddb444d50..ef3390e03fa 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -112,7 +112,7 @@ class ScopeTest(test.TestCase): class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source = parser.parse_entity(test_fn) + node, source, _ = parser.parse_entity(test_fn) entity_info = transformer.EntityInfo( source_code=source, source_file=None, @@ -149,7 +149,7 @@ class ActivityAnalyzerTest(test.TestCase): return c node, _ = self._parse_and_analyze(test_fn) - print_node = node.body[0].body[2] + print_node = node.body[2] if isinstance(print_node, gast.Print): # Python 2 print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) @@ -172,7 +172,7 @@ class ActivityAnalyzerTest(test.TestCase): return c node, _ = self._parse_and_analyze(test_fn) - call_node = node.body[0].body[2].value + call_node = node.body[2].value # We basically need to detect which variables are captured by the call # arguments. self.assertScopeIs( @@ -189,7 +189,7 @@ class ActivityAnalyzerTest(test.TestCase): return a.d node, _ = self._parse_and_analyze(test_fn) - call_node = node.body[0].body[1].value + call_node = node.body[1].value self.assertScopeIs( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), ()) @@ -205,7 +205,7 @@ class ActivityAnalyzerTest(test.TestCase): return a[c] node, _ = self._parse_and_analyze(test_fn) - call_node = node.body[0].body[2].value + call_node = node.body[2].value self.assertScopeIs( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a[0]', 'a[b]', 'b'), ()) @@ -220,7 +220,7 @@ class ActivityAnalyzerTest(test.TestCase): return b, c node, _ = self._parse_and_analyze(test_fn) - while_node = node.body[0].body[1] + while_node = node.body[1] self.assertScopeIs( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c')) self.assertScopeIs( @@ -239,7 +239,7 @@ class ActivityAnalyzerTest(test.TestCase): return b, c node, _ = self._parse_and_analyze(test_fn) - for_node = node.body[0].body[1] + for_node = node.body[1] self.assertScopeIs( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c')) self.assertScopeIs( @@ -260,7 +260,7 @@ class ActivityAnalyzerTest(test.TestCase): return z, u node, _ = self._parse_and_analyze(test_fn) - if_node = node.body[0].body[0] + if_node = node.body[0] self.assertScopeIs( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z')) self.assertScopeIs( @@ -285,7 +285,7 @@ class ActivityAnalyzerTest(test.TestCase): return d node, _ = self._parse_and_analyze(test_fn) - if_node = node.body[0].body[0] + if_node = node.body[0] self.assertScopeIs( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'a.c'), ('a.b', 'd')) self.assertScopeIs( @@ -307,7 +307,7 @@ class ActivityAnalyzerTest(test.TestCase): return d node, _ = self._parse_and_analyze(test_fn) - if_node = node.body[0].body[0] + if_node = node.body[0] self.assertScopeIs( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'), ('a[b]', 'd')) @@ -329,7 +329,7 @@ class ActivityAnalyzerTest(test.TestCase): return a node, _ = self._parse_and_analyze(test_fn) - inner_if_node = node.body[0].body[0].body[0] + inner_if_node = node.body[0].body[0] self.assertScopeIs( anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',)) self.assertScopeIs( @@ -350,7 +350,7 @@ class ActivityAnalyzerTest(test.TestCase): return b, c node, _ = self._parse_and_analyze(test_fn) - fn_def_node = node.body[0].body[0] + fn_def_node = node.body[0] self.assertScopeIs( anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',)) @@ -364,7 +364,7 @@ class ActivityAnalyzerTest(test.TestCase): self.b.c = 1 node, _ = self._parse_and_analyze(TestClass) - init_node = node.body[0].body[0] + init_node = node.body[0] self.assertScopeIs( anno.getanno(init_node, NodeAnno.BODY_SCOPE), ('self', 'a', 'self.b'), ('self', 'self.b', 'self.b.c')) @@ -375,7 +375,7 @@ class ActivityAnalyzerTest(test.TestCase): a[0] += 1 node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node self.assertScopeIs( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), ('a[0]',)) @@ -385,7 +385,7 @@ class ActivityAnalyzerTest(test.TestCase): return c node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node self.assertScopeIs(anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('c',), ()) def test_aug_assign(self): @@ -394,7 +394,7 @@ class ActivityAnalyzerTest(test.TestCase): a += b node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node self.assertScopeIs( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'b'), ('a')) @@ -409,7 +409,7 @@ class ActivityAnalyzerTest(test.TestCase): foo()['bar'] += x node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node self.assertScopeIs( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ()) @@ -419,7 +419,7 @@ class ActivityAnalyzerTest(test.TestCase): return b node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('b',), ()) self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b')) @@ -433,7 +433,7 @@ class ActivityAnalyzerTest(test.TestCase): return lambda: a + b node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('a', 'b'), ()) # Nothing local to the lambda is tracked. @@ -445,7 +445,7 @@ class ActivityAnalyzerTest(test.TestCase): return lambda a: a + b node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('b',), ()) self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') @@ -456,7 +456,7 @@ class ActivityAnalyzerTest(test.TestCase): a = (lambda a, b, c: a + b + c)(d, 1, 2) + b node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('b', 'd'), ('a',)) self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') @@ -467,7 +467,7 @@ class ActivityAnalyzerTest(test.TestCase): a = lambda a, b: d(lambda b: a + b + c) # pylint: disable=undefined-variable node, _ = self._parse_and_analyze(test_fn) - fn_node = node.body[0] + fn_node = node body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) self.assertScopeIs(body_scope, ('c', 'd'), ('a',)) self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py index f8ae3d6eecf..14bb3682e3b 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py @@ -41,7 +41,7 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - node, source = parser.parse_entity(test_fn) + node, source, _ = parser.parse_entity(test_fn) entity_info = transformer.EntityInfo( source_code=source, source_file=None, @@ -67,7 +67,7 @@ class LiveValuesResolverTest(test.TestCase): return a node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) - retval_node = node.body[0].body[0].value + retval_node = node.body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) def test_primitive_values(self): @@ -78,7 +78,7 @@ class LiveValuesResolverTest(test.TestCase): return a node = self._parse_and_analyze(test_fn, {'a': True}) - retval_node = node.body[0].body[0].value + retval_node = node.body[0].value if six.PY2: self.assertEqual( anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) @@ -94,7 +94,7 @@ class LiveValuesResolverTest(test.TestCase): return foo() node = self._parse_and_analyze(test_fn, {'foo': foo}) - func_node = node.body[0].body[0].value.func + func_node = node.body[0].value.func self.assertEquals(foo, anno.getanno(func_node, 'live_val')) self.assertEquals(('foo',), anno.getanno(func_node, 'fqn')) @@ -104,7 +104,7 @@ class LiveValuesResolverTest(test.TestCase): return constant_op.constant(0) node = self._parse_and_analyze(test_fn, {'constant_op': constant_op}) - func_node = node.body[0].body[0].value.func + func_node = node.body[0].value.func self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val')) self.assertEquals((constant_op.__name__, 'constant'), anno.getanno(func_node, 'fqn')) @@ -122,7 +122,7 @@ class LiveValuesResolverTest(test.TestCase): node = self._parse_and_analyze( TestClass.test_fn, {'constant_op': constant_op}, arg_types={'self': (TestClass.__name__, TestClass)}) - func_node = node.body[0].body[0].value.func + func_node = node.body[0].value.func self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index 904386bef4c..c32abb9efd1 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test class LivenessTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source = parser.parse_entity(test_fn) + node, source, _ = parser.parse_entity(test_fn) entity_info = transformer.EntityInfo( source_code=source, source_file=None, @@ -75,7 +75,7 @@ class LivenessTest(test.TestCase): return x node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], ('a', 'x')) self.assertHasLiveOut(fn_body[1], 'x') @@ -92,7 +92,7 @@ class LivenessTest(test.TestCase): return x node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], 'a') self.assertHasLiveOut(fn_body[1], 'x') @@ -105,7 +105,7 @@ class LivenessTest(test.TestCase): return x node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], 'x') @@ -117,7 +117,7 @@ class LivenessTest(test.TestCase): return x.y node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], ('x.y', 'x')) @@ -133,7 +133,7 @@ class LivenessTest(test.TestCase): foo() node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], 'a') @@ -151,7 +151,7 @@ class LivenessTest(test.TestCase): child() node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], 'max') @@ -165,7 +165,7 @@ class LivenessTest(test.TestCase): y = 0 node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveOut(fn_body[0], ()) @@ -179,7 +179,7 @@ class LivenessTest(test.TestCase): return x node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'x')) self.assertHasLiveIn(fn_body[1], ('c', 'x')) @@ -196,7 +196,7 @@ class LivenessTest(test.TestCase): return x node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveIn(fn_body[0], ('a', 'b', 'c', 'd')) self.assertHasLiveIn(fn_body[1], ('d', 'x')) @@ -211,7 +211,7 @@ class LivenessTest(test.TestCase): return y, z node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) @@ -226,7 +226,7 @@ class LivenessTest(test.TestCase): return y, z node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveIn(fn_body[0], ('a', 'y', 'z')) @@ -240,7 +240,7 @@ class LivenessTest(test.TestCase): y = 0 node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasLiveIn(fn_body[0], ('a', 'x', 'y')) @@ -251,7 +251,7 @@ class LivenessTest(test.TestCase): return node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body if six.PY2: self.assertHasLiveIn(fn_body[0], ('all', 'x', 'y')) @@ -265,7 +265,7 @@ class LivenessTest(test.TestCase): return node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body if six.PY2: self.assertHasLiveIn(fn_body[0], ('x', 'y')) @@ -279,7 +279,7 @@ class LivenessTest(test.TestCase): return node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body if six.PY2: self.assertHasLiveIn(fn_body[0], ('x', 'y')) @@ -293,7 +293,7 @@ class LivenessTest(test.TestCase): return node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body if six.PY2: self.assertHasLiveIn(fn_body[0], ('k', 'v', 'y')) diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index 3fb936460da..3359886f50d 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test class DefinitionInfoTest(test.TestCase): def _parse_and_analyze(self, test_fn): - node, source = parser.parse_entity(test_fn) + node, source, _ = parser.parse_entity(test_fn) entity_info = transformer.EntityInfo( source_code=source, source_file=None, @@ -86,7 +86,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasDefs(fn_body[0].targets[0], 1) self.assertHasDefs(fn_body[1].test, 1) @@ -105,7 +105,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasDefs(fn_body[0].value.args[0], 1) self.assertHasDefs(fn_body[1].body[0].targets[0], 1) @@ -128,7 +128,7 @@ class DefinitionInfoTest(test.TestCase): return x, y node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasDefs(fn_body[0].targets[0], 1) self.assertHasDefs(fn_body[1].test, 2) @@ -153,7 +153,7 @@ class DefinitionInfoTest(test.TestCase): return x, y node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasDefs(fn_body[0].targets[0], 1) self.assertHasDefs(fn_body[1].target, 1) @@ -178,7 +178,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body def_of_a_in_if = fn_body[1].body[0].targets[0] self.assertHasDefs(fn_body[0].targets[0], 1) @@ -202,7 +202,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body parent_return = fn_body[3] child_return = fn_body[1].body[1] @@ -219,7 +219,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0) self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1) @@ -232,7 +232,7 @@ class DefinitionInfoTest(test.TestCase): return l node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body creation = fn_body[0].targets[0] mutation = fn_body[1].targets[0].value @@ -251,7 +251,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body first_def = fn_body[0].targets[0] second_def = fn_body[1].orelse[0].targets[0] @@ -270,7 +270,7 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body use = fn_body[2].value self.assertHasDefs(use, 0) @@ -285,9 +285,9 @@ class DefinitionInfoTest(test.TestCase): return a node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body - param = node.body[0].args.args[0] + param = node.args.args[0] source = fn_body[0].value.args[0] target = fn_body[0].targets[0] retval = fn_body[1].value @@ -302,7 +302,7 @@ class DefinitionInfoTest(test.TestCase): return x # pylint:disable=undefined-variable node = self._parse_and_analyze(test_fn) - fn_body = node.body[0].body + fn_body = node.body listcomp_target = fn_body[0].value.args[0].generators[0].target retval = fn_body[1].value diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py index 2263667a9ac..42e52a6b3b9 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py @@ -62,7 +62,7 @@ class TypeInfoResolverTest(test.TestCase): test_fn, namespace, arg_types=None): - node, source = parser.parse_entity(test_fn) + node, source, _ = parser.parse_entity(test_fn) entity_info = transformer.EntityInfo( source_code=source, source_file=None, @@ -87,7 +87,7 @@ class TypeInfoResolverTest(test.TestCase): return opt node = self._parse_and_analyze(test_fn, {'training': training}) - call_node = node.body[0].body[0].value + call_node = node.body[0].value self.assertTrue(anno.getanno(call_node, 'is_constructor')) self.assertEquals(training.GradientDescentOptimizer, anno.getanno(call_node, 'type')) @@ -101,7 +101,7 @@ class TypeInfoResolverTest(test.TestCase): return res node = self._parse_and_analyze(test_fn, {}) - call_node = node.body[0].body[0].value + call_node = node.body[0].value self.assertFalse(anno.hasanno(call_node, 'is_constructor')) def test_class_members_of_detected_constructor(self): @@ -111,7 +111,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze(test_fn, {'training': training}) - method_call = node.body[0].body[1].value.func + method_call = node.body[1].value.func self.assertEquals(training.GradientDescentOptimizer.minimize, anno.getanno(method_call, 'live_val')) @@ -122,12 +122,12 @@ class TypeInfoResolverTest(test.TestCase): sess.run(x) node = self._parse_and_analyze(test_fn, {'session': session}) - constructor_call = node.body[0].body[0].items[0].context_expr + constructor_call = node.body[0].items[0].context_expr self.assertEquals(session.Session, anno.getanno(constructor_call, 'type')) self.assertEquals((session.__name__, 'Session'), anno.getanno(constructor_call, 'type_fqn')) - method_call = node.body[0].body[0].body[0].value.func + method_call = node.body[0].body[0].value.func self.assertEquals(session.Session.run, anno.getanno(method_call, 'live_val')) @@ -141,7 +141,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze(test_fn, {'training': training}) - method_call = node.body[0].body[1].value.func + method_call = node.body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members(self): @@ -150,7 +150,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze(test_fn, {}) - method_call = node.body[0].body[0].value.func + method_call = node.body[0].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members_with_value_hints(self): @@ -165,7 +165,7 @@ class TypeInfoResolverTest(test.TestCase): training.GradientDescentOptimizer) }) - method_call = node.body[0].body[0].value.func + method_call = node.body[0].value.func self.assertEquals(training.GradientDescentOptimizer.minimize, anno.getanno(method_call, 'live_val')) @@ -179,7 +179,7 @@ class TypeInfoResolverTest(test.TestCase): foo() node = self._parse_and_analyze(test_fn, {'bar': bar}) - method_call = node.body[0].body[1].value.func + method_call = node.body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_nested_members(self): @@ -189,7 +189,7 @@ class TypeInfoResolverTest(test.TestCase): foo.bar.baz() node = self._parse_and_analyze(test_fn, {'training': training}) - method_call = node.body[0].body[1].value.func + method_call = node.body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_nested_unpacking(self): @@ -205,7 +205,7 @@ class TypeInfoResolverTest(test.TestCase): return a, b, c node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) - a, b, c = node.body[0].body[1].value.elts + a, b, c = node.body[1].value.elts self.assertEquals(anno.getanno(a, 'type'), Foo) self.assertEquals(anno.getanno(b, 'type'), Bar) self.assertEquals(anno.getanno(c, 'type'), Foo) diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py index 9d83653ad34..bd19ebad5c5 100644 --- a/tensorflow/python/autograph/pyct/transformer_test.py +++ b/tensorflow/python/autograph/pyct/transformer_test.py @@ -68,10 +68,10 @@ class TransformerTest(test.TestCase): return b, inner_function return a, TestClass - node, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function) node = tr.visit(node) - test_function_node = node.body[0] + test_function_node = node test_class = test_function_node.body[1] test_method = test_class.body[0] inner_function = test_method.body[1] @@ -141,10 +141,10 @@ class TransformerTest(test.TestCase): while True: raise '1' - node, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function) node = tr.visit(node) - fn_body = node.body[0].body + fn_body = node.body outer_while_body = fn_body[1].body self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state') self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state') @@ -207,10 +207,10 @@ class TransformerTest(test.TestCase): raise '1' return 'nor this' - node, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function) node = tr.visit(node) - for_node = node.body[0].body[2] + for_node = node.body[2] while_node = for_node.body[1].orelse[1] self.assertFalse(anno.hasanno(for_node, 'string')) @@ -238,7 +238,7 @@ class TransformerTest(test.TestCase): print(a) return None - node, _ = parser.parse_entity(no_exit) + node, _, _ = parser.parse_entity(no_exit) with self.assertRaises(AssertionError): tr.visit(node) @@ -246,7 +246,7 @@ class TransformerTest(test.TestCase): for _ in a: print(a) - node, _ = parser.parse_entity(no_entry) + node, _, _ = parser.parse_entity(no_entry) with self.assertRaises(AssertionError): tr.visit(node) @@ -272,9 +272,8 @@ class TransformerTest(test.TestCase): tr = TestTransformer(self._simple_context()) - node, _ = parser.parse_entity(test_function) + node, _, _ = parser.parse_entity(test_function) node = tr.visit(node) - node = node.body[0] self.assertEqual(len(node.body), 2) self.assertTrue(isinstance(node.body[0], gast.Assign)) @@ -303,9 +302,9 @@ class TransformerTest(test.TestCase): tr = BrokenTransformer(self._simple_context()) - node, _ = parser.parse_entity(test_function) + _, _, all_nodes = parser.parse_entity(test_function) with self.assertRaises(ValueError) as cm: - node = tr.visit(node) + all_nodes = tr.visit(all_nodes) obtained_message = str(cm.exception) expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"' self.assertRegexpMatches(obtained_message, expected_message) @@ -334,9 +333,9 @@ class TransformerTest(test.TestCase): tr = BrokenTransformer(self._simple_context()) - node, _ = parser.parse_entity(test_function) + _, _, all_nodes = parser.parse_entity(test_function) with self.assertRaises(ValueError) as cm: - node = tr.visit(node) + all_nodes = tr.visit(all_nodes) obtained_message = str(cm.exception) # The message should reference the exception actually raised, not anything # from the exception handler.