Resolve the line number offset in a way that's compatible with Python 3.8.

PiperOrigin-RevId: 298103960
Change-Id: I55c4ebdd148870adcaf083162638d10e7073363c
This commit is contained in:
Dan Moldovan 2020-02-29 15:51:04 -08:00 committed by TensorFlower Gardener
parent 07597fe0ce
commit bd15c1679a
3 changed files with 61 additions and 1 deletions

View File

@ -172,7 +172,16 @@ class OriginResolver(gast.NodeVisitor):
self._source_lines = source_lines
self._comments_map = comments_map
self._lineno_offset = context_lineno - root_node.lineno
if (hasattr(root_node, 'decorator_list') and root_node.decorator_list and
hasattr(root_node.decorator_list[0], 'lineno')):
# Typical case: functions. The line number of the first decorator
# is more accurate than the line number of the function itself in
# 3.8+. In earier versions they coincide.
self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno
else:
# Fall back to the line number of the root node.
self._lineno_offset = context_lineno - root_node.lineno
self._col_offset = context_col_offset - root_node.col_offset
self._filepath = filepath

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import textwrap
from tensorflow.python.autograph.pyct import anno
@ -210,6 +211,44 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(ret_origin.source_code_line, ' return self')
self.assertIsNone(ret_origin.comment)
def test_resolve_entity_decorated_function(self):
test_fn = basic_definitions.decorated_function
node, source = parser.parse_entity(
test_fn, inspect_utils.getfutureimports(test_fn))
origin_info.resolve_entity(node, source, test_fn)
# The line numbers below should match those in basic_definitions.py
def_origin = anno.getanno(node, anno.Basic.ORIGIN)
if sys.version_info >= (3, 8):
self.assertEqual(def_origin.loc.lineno, 67)
self.assertEqual(
def_origin.source_code_line, 'def decorated_function(x):')
else:
self.assertEqual(def_origin.loc.lineno, 65)
self.assertEqual(def_origin.source_code_line, '@basic_decorator')
self.assertEqual(def_origin.loc.col_offset, 0)
self.assertIsNone(def_origin.comment)
if_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
self.assertEqual(if_origin.loc.lineno, 68)
self.assertEqual(if_origin.loc.col_offset, 2)
self.assertEqual(if_origin.source_code_line, ' if x > 0:')
self.assertIsNone(if_origin.comment)
ret1_origin = anno.getanno(node.body[0].body[0], anno.Basic.ORIGIN)
self.assertEqual(ret1_origin.loc.lineno, 69)
self.assertEqual(ret1_origin.loc.col_offset, 4)
self.assertEqual(ret1_origin.source_code_line, ' return 1')
self.assertIsNone(ret1_origin.comment)
ret2_origin = anno.getanno(node.body[1], anno.Basic.ORIGIN)
self.assertEqual(ret2_origin.loc.lineno, 70)
self.assertEqual(ret2_origin.loc.col_offset, 2)
self.assertEqual(ret2_origin.source_code_line, ' return 2')
self.assertIsNone(ret2_origin.comment)
if __name__ == '__main__':
test.main()

View File

@ -56,3 +56,15 @@ def function_with_multiline_call(x):
x,
x + 1,
)
def basic_decorator(f):
return f
@basic_decorator
@basic_decorator
def decorated_function(x):
if x > 0:
return 1
return 2