Add initial support for interpolating filename and line number in error messages returned from C++.
PiperOrigin-RevId: 204455158
This commit is contained in:
parent
e438b192d2
commit
895a766788
@ -705,7 +705,9 @@ py_library(
|
||||
"framework/error_interpolation.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [],
|
||||
deps = [
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -1040,6 +1042,7 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":constant_op",
|
||||
":error_interpolation",
|
||||
],
|
||||
)
|
||||
|
@ -29,6 +29,9 @@ import string
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.util import tf_stack
|
||||
|
||||
|
||||
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
|
||||
_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
|
||||
_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
|
||||
@ -38,6 +41,8 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
|
||||
|
||||
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
|
||||
|
||||
_BAD_FILE_SUBSTRINGS = ["tensorflow/python", "<embedded"]
|
||||
|
||||
|
||||
def _parse_message(message):
|
||||
"""Parses the message.
|
||||
@ -48,6 +53,10 @@ def _parse_message(message):
|
||||
"123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
|
||||
three separators. The separators are the numeric characters.
|
||||
|
||||
Supported tags after node:<node_name>
|
||||
file: Replaced with the filename in which the node was defined.
|
||||
line: Replaced by the line number at which the node was defined.
|
||||
|
||||
Args:
|
||||
message: String to parse
|
||||
|
||||
@ -72,9 +81,47 @@ def _parse_message(message):
|
||||
return seps, tags
|
||||
|
||||
|
||||
# TODO(jtkeeling): Modify to actually interpolate format strings rather than
|
||||
# echoing them.
|
||||
def interpolate(error_message):
|
||||
def _get_field_dict_from_traceback(tf_traceback, frame_index):
|
||||
"""Convert traceback elements into interpolation dictionary and return."""
|
||||
frame = tf_traceback[frame_index]
|
||||
return {
|
||||
"file": frame[tf_stack.TB_FILENAME],
|
||||
"line": frame[tf_stack.TB_LINENO],
|
||||
}
|
||||
|
||||
|
||||
def _find_index_of_defining_frame_for_op(op):
|
||||
"""Return index in op._traceback with first 'useful' frame.
|
||||
|
||||
This method reads through the stack stored in op._traceback looking for the
|
||||
innermost frame which (hopefully) belongs to the caller. It accomplishes this
|
||||
by rejecting frames whose filename appears to come from TensorFlow (see
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings).
|
||||
|
||||
Args:
|
||||
op: the Operation object for which we would like to find the defining
|
||||
location.
|
||||
|
||||
Returns:
|
||||
Integer index into op._traceback where the first non-TF file was found
|
||||
(innermost to outermost), or 0 (for the outermost stack frame) if all files
|
||||
came from TensorFlow.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
# Index 0 of tf_traceback is the outermost frame.
|
||||
tf_traceback = tf_stack.convert_stack(op._traceback)
|
||||
size = len(tf_traceback)
|
||||
# pylint: enable=protected-access
|
||||
filenames = [frame[tf_stack.TB_FILENAME] for frame in tf_traceback]
|
||||
# We process the filenames from the innermost frame to outermost.
|
||||
for idx, filename in enumerate(reversed(filenames)):
|
||||
contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS]
|
||||
if not any(contains_bad_substrings):
|
||||
return size - idx - 1
|
||||
return 0
|
||||
|
||||
|
||||
def interpolate(error_message, graph):
|
||||
"""Interpolates an error message.
|
||||
|
||||
The error message can contain tags of the form ^^type:name:format^^ which will
|
||||
@ -82,11 +129,38 @@ def interpolate(error_message):
|
||||
|
||||
Args:
|
||||
error_message: A string to interpolate.
|
||||
graph: ops.Graph object containing all nodes referenced in the error
|
||||
message.
|
||||
|
||||
Returns:
|
||||
The string with tags of the form ^^type:name:format^^ interpolated.
|
||||
"""
|
||||
seps, tags = _parse_message(error_message)
|
||||
subs = [string.Template(tag.format).safe_substitute({}) for tag in tags]
|
||||
|
||||
node_name_to_substitution_dict = {}
|
||||
for name in [t.name for t in tags]:
|
||||
try:
|
||||
op = graph.get_operation_by_name(name)
|
||||
except KeyError:
|
||||
op = None
|
||||
|
||||
if op:
|
||||
frame_index = _find_index_of_defining_frame_for_op(op)
|
||||
# pylint: disable=protected-access
|
||||
field_dict = _get_field_dict_from_traceback(op._traceback, frame_index)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
field_dict = {
|
||||
"file": "<NA>",
|
||||
"line": "<NA>",
|
||||
"func": "<NA>",
|
||||
"code": None,
|
||||
}
|
||||
node_name_to_substitution_dict[name] = field_dict
|
||||
|
||||
subs = [
|
||||
string.Template(tag.format).safe_substitute(
|
||||
node_name_to_substitution_dict[tag.name]) for tag in tags
|
||||
]
|
||||
return "".join(
|
||||
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
|
||||
|
@ -18,31 +18,115 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import tf_stack
|
||||
|
||||
|
||||
def _make_frame_with_filename(op, idx, filename):
|
||||
"""Return a copy of an existing stack frame with a new filename."""
|
||||
stack_frame = list(op._traceback[idx])
|
||||
stack_frame[tf_stack.TB_FILENAME] = filename
|
||||
return tuple(stack_frame)
|
||||
|
||||
|
||||
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
|
||||
num_inner_tf_frames):
|
||||
"""Replace op._traceback with a new traceback using special filenames."""
|
||||
tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0]
|
||||
user_filename = "%d/my_favorite_file.py"
|
||||
|
||||
num_requested_frames = num_user_frames + num_inner_tf_frames
|
||||
num_actual_frames = len(op._traceback)
|
||||
num_outer_frames = num_actual_frames - num_requested_frames
|
||||
assert num_requested_frames <= num_actual_frames, "Too few real frames."
|
||||
|
||||
# The op's traceback has outermost frame at index 0.
|
||||
stack = []
|
||||
for idx in range(0, num_outer_frames):
|
||||
stack.append(op._traceback[idx])
|
||||
for idx in range(len(stack), len(stack)+num_user_frames):
|
||||
stack.append(_make_frame_with_filename(op, idx, user_filename % idx))
|
||||
for idx in range(len(stack), len(stack)+num_inner_tf_frames):
|
||||
stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
|
||||
op._traceback = stack
|
||||
|
||||
|
||||
class InterpolateTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Add nodes to the graph for retrieval by name later.
|
||||
constant_op.constant(1, name="One")
|
||||
constant_op.constant(2, name="Two")
|
||||
three = constant_op.constant(3, name="Three")
|
||||
self.graph = three.graph
|
||||
|
||||
# Change the list of bad file substrings so that constant_op.py is chosen
|
||||
# as the defining stack frame for constant_op.constant ops.
|
||||
self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = ["/ops.py", "/util"]
|
||||
|
||||
def tearDown(self):
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
|
||||
|
||||
def testFindIndexOfDefiningFrameForOp(self):
|
||||
local_op = constant_op.constant(42).op
|
||||
user_filename = "hope.py"
|
||||
_modify_op_stack_with_filenames(local_op,
|
||||
num_user_frames=3,
|
||||
user_filename=user_filename,
|
||||
num_inner_tf_frames=5)
|
||||
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||
# Expected frame is 6th from the end because there are 5 inner frames witih
|
||||
# TF filenames.
|
||||
expected_frame = len(local_op._traceback) - 6
|
||||
self.assertEqual(expected_frame, idx)
|
||||
|
||||
def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
|
||||
local_op = constant_op.constant(43).op
|
||||
# Truncate stack to known length.
|
||||
local_op._traceback = local_op._traceback[:7]
|
||||
# Ensure all frames look like TF frames.
|
||||
_modify_op_stack_with_filenames(local_op,
|
||||
num_user_frames=0,
|
||||
user_filename="user_file.py",
|
||||
num_inner_tf_frames=7)
|
||||
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||
self.assertEqual(0, idx)
|
||||
|
||||
def testNothingToDo(self):
|
||||
normal_string = "This is just a normal string"
|
||||
interpolated_string = error_interpolation.interpolate(normal_string)
|
||||
interpolated_string = error_interpolation.interpolate(normal_string,
|
||||
self.graph)
|
||||
self.assertEqual(interpolated_string, normal_string)
|
||||
|
||||
def testOneTag(self):
|
||||
one_tag_string = "^^node:Foo:${file}^^"
|
||||
interpolated_string = error_interpolation.interpolate(one_tag_string)
|
||||
self.assertEqual(interpolated_string, "${file}")
|
||||
one_tag_string = "^^node:Two:${file}^^"
|
||||
interpolated_string = error_interpolation.interpolate(one_tag_string,
|
||||
self.graph)
|
||||
self.assertTrue(interpolated_string.endswith("constant_op.py"),
|
||||
"interpolated_string '%s' did not end with constant_op.py"
|
||||
% interpolated_string)
|
||||
|
||||
def testOneTagWithAFakeNameResultsInPlaceholders(self):
|
||||
one_tag_string = "^^node:MinusOne:${file}^^"
|
||||
interpolated_string = error_interpolation.interpolate(one_tag_string,
|
||||
self.graph)
|
||||
self.assertEqual(interpolated_string, "<NA>")
|
||||
|
||||
def testTwoTagsNoSeps(self):
|
||||
two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^"
|
||||
interpolated_string = error_interpolation.interpolate(two_tags_no_seps)
|
||||
self.assertEqual(interpolated_string, "${file}${line}")
|
||||
two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^"
|
||||
interpolated_string = error_interpolation.interpolate(two_tags_no_seps,
|
||||
self.graph)
|
||||
self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+")
|
||||
|
||||
def testTwoTagsWithSeps(self):
|
||||
two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789"
|
||||
interpolated_string = error_interpolation.interpolate(two_tags_with_seps)
|
||||
self.assertEqual(interpolated_string, "123${file}456${line}789")
|
||||
two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;"
|
||||
interpolated_string = error_interpolation.interpolate(two_tags_with_seps,
|
||||
self.graph)
|
||||
expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$"
|
||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -21,6 +21,12 @@ from __future__ import print_function
|
||||
import linecache
|
||||
import sys
|
||||
|
||||
# Names for indices into TF traceback tuples.
|
||||
TB_FILENAME = 0
|
||||
TB_LINENO = 1
|
||||
TB_FUNCNAME = 2
|
||||
TB_CODEDICT = 3 # Dictionary of Python interpreter state.
|
||||
|
||||
|
||||
def extract_stack(extract_frame_info_fn=None):
|
||||
"""A lightweight, extensible re-implementation of traceback.extract_stack.
|
||||
|
Loading…
x
Reference in New Issue
Block a user