From bd15c1679a0dae41b7cdbc368ff6deb0b6eb386b Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Sat, 29 Feb 2020 15:51:04 -0800 Subject: [PATCH] Resolve the line number offset in a way that's compatible with Python 3.8. PiperOrigin-RevId: 298103960 Change-Id: I55c4ebdd148870adcaf083162638d10e7073363c --- .../python/autograph/pyct/origin_info.py | 11 +++++- .../python/autograph/pyct/origin_info_test.py | 39 +++++++++++++++++++ .../pyct/testing/basic_definitions.py | 12 ++++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index 32f0462cb9a..2a040adc00e 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -172,7 +172,16 @@ class OriginResolver(gast.NodeVisitor): self._source_lines = source_lines self._comments_map = comments_map - self._lineno_offset = context_lineno - root_node.lineno + if (hasattr(root_node, 'decorator_list') and root_node.decorator_list and + hasattr(root_node.decorator_list[0], 'lineno')): + # Typical case: functions. The line number of the first decorator + # is more accurate than the line number of the function itself in + # 3.8+. In earier versions they coincide. + self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno + else: + # Fall back to the line number of the root node. + self._lineno_offset = context_lineno - root_node.lineno + self._col_offset = context_col_offset - root_node.col_offset self._filepath = filepath diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py index 01ded4cc559..823dacfe2ed 100644 --- a/tensorflow/python/autograph/pyct/origin_info_test.py +++ b/tensorflow/python/autograph/pyct/origin_info_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import textwrap from tensorflow.python.autograph.pyct import anno @@ -210,6 +211,44 @@ class OriginInfoTest(test.TestCase): self.assertEqual(ret_origin.source_code_line, ' return self') self.assertIsNone(ret_origin.comment) + def test_resolve_entity_decorated_function(self): + + test_fn = basic_definitions.decorated_function + node, source = parser.parse_entity( + test_fn, inspect_utils.getfutureimports(test_fn)) + origin_info.resolve_entity(node, source, test_fn) + + # The line numbers below should match those in basic_definitions.py + + def_origin = anno.getanno(node, anno.Basic.ORIGIN) + if sys.version_info >= (3, 8): + self.assertEqual(def_origin.loc.lineno, 67) + self.assertEqual( + def_origin.source_code_line, 'def decorated_function(x):') + else: + self.assertEqual(def_origin.loc.lineno, 65) + self.assertEqual(def_origin.source_code_line, '@basic_decorator') + self.assertEqual(def_origin.loc.col_offset, 0) + self.assertIsNone(def_origin.comment) + + if_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN) + self.assertEqual(if_origin.loc.lineno, 68) + self.assertEqual(if_origin.loc.col_offset, 2) + self.assertEqual(if_origin.source_code_line, ' if x > 0:') + self.assertIsNone(if_origin.comment) + + ret1_origin = anno.getanno(node.body[0].body[0], anno.Basic.ORIGIN) + self.assertEqual(ret1_origin.loc.lineno, 69) + self.assertEqual(ret1_origin.loc.col_offset, 4) + self.assertEqual(ret1_origin.source_code_line, ' return 1') + self.assertIsNone(ret1_origin.comment) + + ret2_origin = anno.getanno(node.body[1], anno.Basic.ORIGIN) + self.assertEqual(ret2_origin.loc.lineno, 70) + self.assertEqual(ret2_origin.loc.col_offset, 2) + self.assertEqual(ret2_origin.source_code_line, ' return 2') + self.assertIsNone(ret2_origin.comment) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/testing/basic_definitions.py b/tensorflow/python/autograph/pyct/testing/basic_definitions.py index 3b4253e3312..ee824a4eaad 100644 --- a/tensorflow/python/autograph/pyct/testing/basic_definitions.py +++ b/tensorflow/python/autograph/pyct/testing/basic_definitions.py @@ -56,3 +56,15 @@ def function_with_multiline_call(x): x, x + 1, ) + + +def basic_decorator(f): + return f + + +@basic_decorator +@basic_decorator +def decorated_function(x): + if x > 0: + return 1 + return 2