Add initial support for interpolating filename and line number in error messages returned from C++.
PiperOrigin-RevId: 204455158
This commit is contained in:
parent
e438b192d2
commit
895a766788
@ -705,7 +705,9 @@ py_library(
|
|||||||
"framework/error_interpolation.py",
|
"framework/error_interpolation.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [],
|
deps = [
|
||||||
|
":util",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
@ -1040,6 +1042,7 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
|
":constant_op",
|
||||||
":error_interpolation",
|
":error_interpolation",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -29,6 +29,9 @@ import string
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.util import tf_stack
|
||||||
|
|
||||||
|
|
||||||
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
|
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
|
||||||
_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
|
_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
|
||||||
_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
|
_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
|
||||||
@ -38,6 +41,8 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
|
|||||||
|
|
||||||
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
|
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
|
||||||
|
|
||||||
|
_BAD_FILE_SUBSTRINGS = ["tensorflow/python", "<embedded"]
|
||||||
|
|
||||||
|
|
||||||
def _parse_message(message):
|
def _parse_message(message):
|
||||||
"""Parses the message.
|
"""Parses the message.
|
||||||
@ -48,6 +53,10 @@ def _parse_message(message):
|
|||||||
"123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
|
"123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
|
||||||
three separators. The separators are the numeric characters.
|
three separators. The separators are the numeric characters.
|
||||||
|
|
||||||
|
Supported tags after node:<node_name>
|
||||||
|
file: Replaced with the filename in which the node was defined.
|
||||||
|
line: Replaced by the line number at which the node was defined.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: String to parse
|
message: String to parse
|
||||||
|
|
||||||
@ -72,9 +81,47 @@ def _parse_message(message):
|
|||||||
return seps, tags
|
return seps, tags
|
||||||
|
|
||||||
|
|
||||||
# TODO(jtkeeling): Modify to actually interpolate format strings rather than
|
def _get_field_dict_from_traceback(tf_traceback, frame_index):
|
||||||
# echoing them.
|
"""Convert traceback elements into interpolation dictionary and return."""
|
||||||
def interpolate(error_message):
|
frame = tf_traceback[frame_index]
|
||||||
|
return {
|
||||||
|
"file": frame[tf_stack.TB_FILENAME],
|
||||||
|
"line": frame[tf_stack.TB_LINENO],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _find_index_of_defining_frame_for_op(op):
|
||||||
|
"""Return index in op._traceback with first 'useful' frame.
|
||||||
|
|
||||||
|
This method reads through the stack stored in op._traceback looking for the
|
||||||
|
innermost frame which (hopefully) belongs to the caller. It accomplishes this
|
||||||
|
by rejecting frames whose filename appears to come from TensorFlow (see
|
||||||
|
error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: the Operation object for which we would like to find the defining
|
||||||
|
location.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer index into op._traceback where the first non-TF file was found
|
||||||
|
(innermost to outermost), or 0 (for the outermost stack frame) if all files
|
||||||
|
came from TensorFlow.
|
||||||
|
"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
# Index 0 of tf_traceback is the outermost frame.
|
||||||
|
tf_traceback = tf_stack.convert_stack(op._traceback)
|
||||||
|
size = len(tf_traceback)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
filenames = [frame[tf_stack.TB_FILENAME] for frame in tf_traceback]
|
||||||
|
# We process the filenames from the innermost frame to outermost.
|
||||||
|
for idx, filename in enumerate(reversed(filenames)):
|
||||||
|
contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS]
|
||||||
|
if not any(contains_bad_substrings):
|
||||||
|
return size - idx - 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate(error_message, graph):
|
||||||
"""Interpolates an error message.
|
"""Interpolates an error message.
|
||||||
|
|
||||||
The error message can contain tags of the form ^^type:name:format^^ which will
|
The error message can contain tags of the form ^^type:name:format^^ which will
|
||||||
@ -82,11 +129,38 @@ def interpolate(error_message):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_message: A string to interpolate.
|
error_message: A string to interpolate.
|
||||||
|
graph: ops.Graph object containing all nodes referenced in the error
|
||||||
|
message.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The string with tags of the form ^^type:name:format^^ interpolated.
|
The string with tags of the form ^^type:name:format^^ interpolated.
|
||||||
"""
|
"""
|
||||||
seps, tags = _parse_message(error_message)
|
seps, tags = _parse_message(error_message)
|
||||||
subs = [string.Template(tag.format).safe_substitute({}) for tag in tags]
|
|
||||||
|
node_name_to_substitution_dict = {}
|
||||||
|
for name in [t.name for t in tags]:
|
||||||
|
try:
|
||||||
|
op = graph.get_operation_by_name(name)
|
||||||
|
except KeyError:
|
||||||
|
op = None
|
||||||
|
|
||||||
|
if op:
|
||||||
|
frame_index = _find_index_of_defining_frame_for_op(op)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
field_dict = _get_field_dict_from_traceback(op._traceback, frame_index)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
else:
|
||||||
|
field_dict = {
|
||||||
|
"file": "<NA>",
|
||||||
|
"line": "<NA>",
|
||||||
|
"func": "<NA>",
|
||||||
|
"code": None,
|
||||||
|
}
|
||||||
|
node_name_to_substitution_dict[name] = field_dict
|
||||||
|
|
||||||
|
subs = [
|
||||||
|
string.Template(tag.format).safe_substitute(
|
||||||
|
node_name_to_substitution_dict[tag.name]) for tag in tags
|
||||||
|
]
|
||||||
return "".join(
|
return "".join(
|
||||||
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
|
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
|
||||||
|
@ -18,31 +18,115 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import error_interpolation
|
from tensorflow.python.framework import error_interpolation
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.util import tf_stack
|
||||||
|
|
||||||
|
|
||||||
|
def _make_frame_with_filename(op, idx, filename):
|
||||||
|
"""Return a copy of an existing stack frame with a new filename."""
|
||||||
|
stack_frame = list(op._traceback[idx])
|
||||||
|
stack_frame[tf_stack.TB_FILENAME] = filename
|
||||||
|
return tuple(stack_frame)
|
||||||
|
|
||||||
|
|
||||||
|
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
|
||||||
|
num_inner_tf_frames):
|
||||||
|
"""Replace op._traceback with a new traceback using special filenames."""
|
||||||
|
tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0]
|
||||||
|
user_filename = "%d/my_favorite_file.py"
|
||||||
|
|
||||||
|
num_requested_frames = num_user_frames + num_inner_tf_frames
|
||||||
|
num_actual_frames = len(op._traceback)
|
||||||
|
num_outer_frames = num_actual_frames - num_requested_frames
|
||||||
|
assert num_requested_frames <= num_actual_frames, "Too few real frames."
|
||||||
|
|
||||||
|
# The op's traceback has outermost frame at index 0.
|
||||||
|
stack = []
|
||||||
|
for idx in range(0, num_outer_frames):
|
||||||
|
stack.append(op._traceback[idx])
|
||||||
|
for idx in range(len(stack), len(stack)+num_user_frames):
|
||||||
|
stack.append(_make_frame_with_filename(op, idx, user_filename % idx))
|
||||||
|
for idx in range(len(stack), len(stack)+num_inner_tf_frames):
|
||||||
|
stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
|
||||||
|
op._traceback = stack
|
||||||
|
|
||||||
|
|
||||||
class InterpolateTest(test.TestCase):
|
class InterpolateTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Add nodes to the graph for retrieval by name later.
|
||||||
|
constant_op.constant(1, name="One")
|
||||||
|
constant_op.constant(2, name="Two")
|
||||||
|
three = constant_op.constant(3, name="Three")
|
||||||
|
self.graph = three.graph
|
||||||
|
|
||||||
|
# Change the list of bad file substrings so that constant_op.py is chosen
|
||||||
|
# as the defining stack frame for constant_op.constant ops.
|
||||||
|
self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
|
||||||
|
error_interpolation._BAD_FILE_SUBSTRINGS = ["/ops.py", "/util"]
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
|
||||||
|
|
||||||
|
def testFindIndexOfDefiningFrameForOp(self):
|
||||||
|
local_op = constant_op.constant(42).op
|
||||||
|
user_filename = "hope.py"
|
||||||
|
_modify_op_stack_with_filenames(local_op,
|
||||||
|
num_user_frames=3,
|
||||||
|
user_filename=user_filename,
|
||||||
|
num_inner_tf_frames=5)
|
||||||
|
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||||
|
# Expected frame is 6th from the end because there are 5 inner frames witih
|
||||||
|
# TF filenames.
|
||||||
|
expected_frame = len(local_op._traceback) - 6
|
||||||
|
self.assertEqual(expected_frame, idx)
|
||||||
|
|
||||||
|
def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
|
||||||
|
local_op = constant_op.constant(43).op
|
||||||
|
# Truncate stack to known length.
|
||||||
|
local_op._traceback = local_op._traceback[:7]
|
||||||
|
# Ensure all frames look like TF frames.
|
||||||
|
_modify_op_stack_with_filenames(local_op,
|
||||||
|
num_user_frames=0,
|
||||||
|
user_filename="user_file.py",
|
||||||
|
num_inner_tf_frames=7)
|
||||||
|
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||||
|
self.assertEqual(0, idx)
|
||||||
|
|
||||||
def testNothingToDo(self):
|
def testNothingToDo(self):
|
||||||
normal_string = "This is just a normal string"
|
normal_string = "This is just a normal string"
|
||||||
interpolated_string = error_interpolation.interpolate(normal_string)
|
interpolated_string = error_interpolation.interpolate(normal_string,
|
||||||
|
self.graph)
|
||||||
self.assertEqual(interpolated_string, normal_string)
|
self.assertEqual(interpolated_string, normal_string)
|
||||||
|
|
||||||
def testOneTag(self):
|
def testOneTag(self):
|
||||||
one_tag_string = "^^node:Foo:${file}^^"
|
one_tag_string = "^^node:Two:${file}^^"
|
||||||
interpolated_string = error_interpolation.interpolate(one_tag_string)
|
interpolated_string = error_interpolation.interpolate(one_tag_string,
|
||||||
self.assertEqual(interpolated_string, "${file}")
|
self.graph)
|
||||||
|
self.assertTrue(interpolated_string.endswith("constant_op.py"),
|
||||||
|
"interpolated_string '%s' did not end with constant_op.py"
|
||||||
|
% interpolated_string)
|
||||||
|
|
||||||
|
def testOneTagWithAFakeNameResultsInPlaceholders(self):
|
||||||
|
one_tag_string = "^^node:MinusOne:${file}^^"
|
||||||
|
interpolated_string = error_interpolation.interpolate(one_tag_string,
|
||||||
|
self.graph)
|
||||||
|
self.assertEqual(interpolated_string, "<NA>")
|
||||||
|
|
||||||
def testTwoTagsNoSeps(self):
|
def testTwoTagsNoSeps(self):
|
||||||
two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^"
|
two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^"
|
||||||
interpolated_string = error_interpolation.interpolate(two_tags_no_seps)
|
interpolated_string = error_interpolation.interpolate(two_tags_no_seps,
|
||||||
self.assertEqual(interpolated_string, "${file}${line}")
|
self.graph)
|
||||||
|
self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+")
|
||||||
|
|
||||||
def testTwoTagsWithSeps(self):
|
def testTwoTagsWithSeps(self):
|
||||||
two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789"
|
two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;"
|
||||||
interpolated_string = error_interpolation.interpolate(two_tags_with_seps)
|
interpolated_string = error_interpolation.interpolate(two_tags_with_seps,
|
||||||
self.assertEqual(interpolated_string, "123${file}456${line}789")
|
self.graph)
|
||||||
|
expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$"
|
||||||
|
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -21,6 +21,12 @@ from __future__ import print_function
|
|||||||
import linecache
|
import linecache
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
# Names for indices into TF traceback tuples.
|
||||||
|
TB_FILENAME = 0
|
||||||
|
TB_LINENO = 1
|
||||||
|
TB_FUNCNAME = 2
|
||||||
|
TB_CODEDICT = 3 # Dictionary of Python interpreter state.
|
||||||
|
|
||||||
|
|
||||||
def extract_stack(extract_frame_info_fn=None):
|
def extract_stack(extract_frame_info_fn=None):
|
||||||
"""A lightweight, extensible re-implementation of traceback.extract_stack.
|
"""A lightweight, extensible re-implementation of traceback.extract_stack.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user