From 895a7667884545a68480eb91916a5a23c2852308 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 13 Jul 2018 04:51:48 -0700 Subject: [PATCH] Add initial support for interpolating filename and line number in error messages returned from C++. PiperOrigin-RevId: 204455158 --- tensorflow/python/BUILD | 5 +- .../python/framework/error_interpolation.py | 82 +++++++++++++- .../framework/error_interpolation_test.py | 104 ++++++++++++++++-- tensorflow/python/util/tf_stack.py | 6 + 4 files changed, 182 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 924db54cbcf..2fba3c2acb6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -705,7 +705,9 @@ py_library( "framework/error_interpolation.py", ], srcs_version = "PY2AND3", - deps = [], + deps = [ + ":util", + ], ) py_library( @@ -1040,6 +1042,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":client_testlib", + ":constant_op", ":error_interpolation", ], ) diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index 9ccae761471..519e0fda0a4 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -29,6 +29,9 @@ import string import six +from tensorflow.python.util import tf_stack + + _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" _FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+" _TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format( @@ -38,6 +41,8 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) _ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"]) +_BAD_FILE_SUBSTRINGS = ["tensorflow/python", " + file: Replaced with the filename in which the node was defined. + line: Replaced by the line number at which the node was defined. 
+ Args: message: String to parse @@ -72,9 +81,47 @@ def _parse_message(message): return seps, tags -# TODO(jtkeeling): Modify to actually interpolate format strings rather than -# echoing them. -def interpolate(error_message): +def _get_field_dict_from_traceback(tf_traceback, frame_index): + """Convert traceback elements into interpolation dictionary and return.""" + frame = tf_traceback[frame_index] + return { + "file": frame[tf_stack.TB_FILENAME], + "line": frame[tf_stack.TB_LINENO], + } + + +def _find_index_of_defining_frame_for_op(op): + """Return index in op._traceback with first 'useful' frame. + + This method reads through the stack stored in op._traceback looking for the + innermost frame which (hopefully) belongs to the caller. It accomplishes this + by rejecting frames whose filename appears to come from TensorFlow (see + error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings). + + Args: + op: the Operation object for which we would like to find the defining + location. + + Returns: + Integer index into op._traceback where the first non-TF file was found + (innermost to outermost), or 0 (for the outermost stack frame) if all files + came from TensorFlow. + """ + # pylint: disable=protected-access + # Index 0 of tf_traceback is the outermost frame. + tf_traceback = tf_stack.convert_stack(op._traceback) + size = len(tf_traceback) + # pylint: enable=protected-access + filenames = [frame[tf_stack.TB_FILENAME] for frame in tf_traceback] + # We process the filenames from the innermost frame to outermost. + for idx, filename in enumerate(reversed(filenames)): + contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS] + if not any(contains_bad_substrings): + return size - idx - 1 + return 0 + + +def interpolate(error_message, graph): """Interpolates an error message. 
The error message can contain tags of the form ^^type:name:format^^ which will @@ -82,11 +129,38 @@ def interpolate(error_message): Args: error_message: A string to interpolate. + graph: ops.Graph object containing all nodes referenced in the error + message. Returns: The string with tags of the form ^^type:name:format^^ interpolated. """ seps, tags = _parse_message(error_message) - subs = [string.Template(tag.format).safe_substitute({}) for tag in tags] + + node_name_to_substitution_dict = {} + for name in [t.name for t in tags]: + try: + op = graph.get_operation_by_name(name) + except KeyError: + op = None + + if op: + frame_index = _find_index_of_defining_frame_for_op(op) + # pylint: disable=protected-access + field_dict = _get_field_dict_from_traceback(op._traceback, frame_index) + # pylint: enable=protected-access + else: + field_dict = { + "file": "", + "line": "", + "func": "", + "code": None, + } + node_name_to_substitution_dict[name] = field_dict + + subs = [ + string.Template(tag.format).safe_substitute( + node_name_to_substitution_dict[tag.name]) for tag in tags + ] return "".join( itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue=""))) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index ad448deb622..091f0da2a2d 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -18,31 +18,115 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import error_interpolation from tensorflow.python.platform import test +from tensorflow.python.util import tf_stack + + +def _make_frame_with_filename(op, idx, filename): + """Return a copy of an existing stack frame with a new filename.""" + stack_frame = list(op._traceback[idx]) + stack_frame[tf_stack.TB_FILENAME] 
= filename + return tuple(stack_frame) + + +def _modify_op_stack_with_filenames(op, num_user_frames, user_filename, + num_inner_tf_frames): + """Replace op._traceback with a new traceback using special filenames.""" + tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0] + user_filename = "%d/my_favorite_file.py" + + num_requested_frames = num_user_frames + num_inner_tf_frames + num_actual_frames = len(op._traceback) + num_outer_frames = num_actual_frames - num_requested_frames + assert num_requested_frames <= num_actual_frames, "Too few real frames." + + # The op's traceback has outermost frame at index 0. + stack = [] + for idx in range(0, num_outer_frames): + stack.append(op._traceback[idx]) + for idx in range(len(stack), len(stack)+num_user_frames): + stack.append(_make_frame_with_filename(op, idx, user_filename % idx)) + for idx in range(len(stack), len(stack)+num_inner_tf_frames): + stack.append(_make_frame_with_filename(op, idx, tf_filename % idx)) + op._traceback = stack class InterpolateTest(test.TestCase): + def setUp(self): + # Add nodes to the graph for retrieval by name later. + constant_op.constant(1, name="One") + constant_op.constant(2, name="Two") + three = constant_op.constant(3, name="Three") + self.graph = three.graph + + # Change the list of bad file substrings so that constant_op.py is chosen + # as the defining stack frame for constant_op.constant ops. 
+ self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS + error_interpolation._BAD_FILE_SUBSTRINGS = ["/ops.py", "/util"] + + def tearDown(self): + error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings + + def testFindIndexOfDefiningFrameForOp(self): + local_op = constant_op.constant(42).op + user_filename = "hope.py" + _modify_op_stack_with_filenames(local_op, + num_user_frames=3, + user_filename=user_filename, + num_inner_tf_frames=5) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + # Expected frame is 6th from the end because there are 5 inner frames with + # TF filenames. + expected_frame = len(local_op._traceback) - 6 + self.assertEqual(expected_frame, idx) + + def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self): + local_op = constant_op.constant(43).op + # Truncate stack to known length. + local_op._traceback = local_op._traceback[:7] + # Ensure all frames look like TF frames. + _modify_op_stack_with_filenames(local_op, + num_user_frames=0, + user_filename="user_file.py", + num_inner_tf_frames=7) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + self.assertEqual(0, idx) + def testNothingToDo(self): normal_string = "This is just a normal string" - interpolated_string = error_interpolation.interpolate(normal_string) + interpolated_string = error_interpolation.interpolate(normal_string, + self.graph) self.assertEqual(interpolated_string, normal_string) def testOneTag(self): - one_tag_string = "^^node:Foo:${file}^^" - interpolated_string = error_interpolation.interpolate(one_tag_string) - self.assertEqual(interpolated_string, "${file}") + one_tag_string = "^^node:Two:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertTrue(interpolated_string.endswith("constant_op.py"), + "interpolated_string '%s' did not end with constant_op.py" + % interpolated_string) + + def testOneTagWithAFakeNameResultsInPlaceholders(self): +
one_tag_string = "^^node:MinusOne:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertEqual(interpolated_string, "") def testTwoTagsNoSeps(self): - two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^" - interpolated_string = error_interpolation.interpolate(two_tags_no_seps) - self.assertEqual(interpolated_string, "${file}${line}") + two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^" + interpolated_string = error_interpolation.interpolate(two_tags_no_seps, + self.graph) + self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+") def testTwoTagsWithSeps(self): - two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789" - interpolated_string = error_interpolation.interpolate(two_tags_with_seps) - self.assertEqual(interpolated_string, "123${file}456${line}789") + two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;" + interpolated_string = error_interpolation.interpolate(two_tags_with_seps, + self.graph) + expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$" + self.assertRegexpMatches(interpolated_string, expected_regex) if __name__ == "__main__": diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py index dacc1ce83e1..fe4f4a63eb5 100644 --- a/tensorflow/python/util/tf_stack.py +++ b/tensorflow/python/util/tf_stack.py @@ -21,6 +21,12 @@ from __future__ import print_function import linecache import sys +# Names for indices into TF traceback tuples. +TB_FILENAME = 0 +TB_LINENO = 1 +TB_FUNCNAME = 2 +TB_CODEDICT = 3 # Dictionary of Python interpreter state. + def extract_stack(extract_frame_info_fn=None): """A lightweight, extensible re-implementation of traceback.extract_stack.