diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 956e90999c7..1ef0504ecb8 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -840,7 +840,6 @@ py_test( python_version = "PY3", srcs_version = "PY2AND3", tags = [ - "no_oss_py38", #TODO(b/151449908) "no_windows", ], deps = [ diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index faf2365fc9c..89964a21ba7 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import ast import os +import sys import tempfile import zipfile @@ -43,7 +45,41 @@ from tensorflow.python.util import tf_inspect def line_number_above(): - return tf_inspect.stack()[1][2] - 1 + """Get lineno of the AST node immediately above this function's call site. + + It is assumed that there is no empty line(s) between the call site and the + preceding AST node. + + Returns: + The lineno of the preceding AST node, at the same level of the AST. + If the preceding AST spans multiple lines: + - In Python 3.8+, the lineno of the first line is returned. + - In older Python versions, the lineno of the last line is returned. + """ + # https://bugs.python.org/issue12458: In Python 3.8, traceback started + # to return the lineno of the first line of a multi-line continuation block, + # instead of that of the last line. Therefore, in Python 3.8+, we use `ast` to + # get the lineno of the first line. + call_site_lineno = tf_inspect.stack()[1][2] + if sys.version_info < (3, 8): + return call_site_lineno - 1 + else: + with open(__file__, "rb") as f: + source_text = f.read().decode("utf-8") + source_tree = ast.parse(source_text) + prev_node = _find_preceding_ast_node(source_tree, call_site_lineno) + return prev_node.lineno + + +def _find_preceding_ast_node(node, lineno): + """Find the ast node immediately before and not including lineno.""" + for i, child_node in enumerate(node.body): + if child_node.lineno == lineno: + return node.body[i - 1] + if hasattr(child_node, "body"): + found_node = _find_preceding_ast_node(child_node, lineno) + if found_node: + return found_node class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):