Make error interpolation and debug info tolerant of ops without a traceback.

* This was encountered when attempting to save debug info for a SavedModel with a functional while loop. The Placeholder op of the while loop condition lacked a _traceback attribute, causing save to fail.
* I think this is expected behavior (no traceback on synthetic ops). Even if not, I believe that having the error_interpolator be conservatively correct in the face of a missing traceback is proper.
* Cleans up some protected functions to take a traceback instead of an op.
* Makes compute_field_dict protected since it is not used outside of this module.

PiperOrigin-RevId: 280244393
Change-Id: I3af66387c6d3af1a779c84775ec4704e2b354859
This commit is contained in:
A. Unique TensorFlower 2019-11-13 11:43:08 -08:00 committed by TensorFlower Gardener
parent af1cb5fe10
commit 2fb9116740
2 changed files with 55 additions and 32 deletions

View File

@ -223,7 +223,7 @@ def _is_framework_filename(filename):
return False
def _find_index_of_defining_frame_for_op(op):
def _find_index_of_defining_frame(traceback):
"""Return index in op.traceback with first 'useful' frame.
This method reads through the stack stored in op.traceback looking for the
@ -232,18 +232,16 @@ def _find_index_of_defining_frame_for_op(op):
pattern matching the filename).
Args:
op: the Operation object for which we would like to find the defining
location.
traceback: A list of traceback frames (as from Operation.traceback).
Returns:
Integer index into op.traceback where the first non-TF file was found
(innermost to outermost), or 0 (for the outermost stack frame) if all files
came from TensorFlow.
"""
# Index 0 of tf_traceback is the outermost frame.
tf_traceback = op.traceback
size = len(tf_traceback)
filenames = [frame.filename for frame in tf_traceback]
# Index 0 of traceback is the outermost frame.
size = len(traceback)
filenames = [frame.filename for frame in traceback]
# We process the filenames from the innermost frame to outermost.
for idx, filename in enumerate(reversed(filenames)):
is_framework = _is_framework_filename(filename)
@ -253,13 +251,13 @@ def _find_index_of_defining_frame_for_op(op):
return 0
def _get_defining_frame_from_op(op):
def _get_defining_frame(traceback):
"""Find and return stack frame where op was defined."""
frame_index = _find_index_of_defining_frame_for_op(op)
return op.traceback[frame_index]
frame_index = _find_index_of_defining_frame(traceback)
return traceback[frame_index]
def _compute_useful_frames(op, num):
def _compute_useful_frames(traceback, num):
"""Return a list of frames, which form a 'useful' stack.
Starting from the defining frame to the outermost one, this method computes
@ -267,20 +265,20 @@ def _compute_useful_frames(op, num):
frames.
Args:
op: op.Operation object having a _traceback member.
traceback: A list of traceback frames (as from Operation.traceback).
num: total number of frames to return.
Returns:
A list of frames.
"""
defining_frame_index = _find_index_of_defining_frame_for_op(op)
defining_frame_index = _find_index_of_defining_frame(traceback)
# The stack trace is collected from two lines before the defining frame in the
# model file to the outermost with `num` frames at most. These two extra lines
# are included from the TensorFlow library to give the context which node is
# defined.
innermost_excluded = min(defining_frame_index + 2 + 1, len(op.traceback))
innermost_excluded = min(defining_frame_index + 2 + 1, len(traceback))
outermost_included = max(innermost_excluded - num, 0)
return op.traceback[outermost_included:innermost_excluded]
return traceback[outermost_included:innermost_excluded]
def create_graph_debug_info_def(func_named_operations):
@ -305,9 +303,16 @@ def create_graph_debug_info_def(func_named_operations):
all_file_names = set()
node_to_trace = {}
for func_name, op in func_named_operations:
try:
op_traceback = op.traceback
except AttributeError:
# Some ops synthesized on as part of function or control flow definition
# do not have tracebacks.
continue
# Gets the stack trace of the operation and then the file location.
node_name = op.name + "@" + func_name
node_to_trace[node_name] = _compute_useful_frames(op, 10)
node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10)
for frame in node_to_trace[node_name]:
all_file_names.add(frame.filename)
@ -332,7 +337,7 @@ def create_graph_debug_info_def(func_named_operations):
return graph_debug_info_def
def compute_field_dict(op, strip_file_prefix=""):
def _compute_field_dict(op, strip_file_prefix=""):
"""Return a dictionary mapping interpolation tokens to values.
Args:
@ -364,23 +369,34 @@ def compute_field_dict(op, strip_file_prefix=""):
with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
}
"""
frame = _get_defining_frame_from_op(op)
filename = frame.filename
if filename.startswith(strip_file_prefix):
filename = filename[len(strip_file_prefix):]
lineno = frame.lineno
defined_at = " (defined at %s:%d)" % (filename, lineno)
colocation_summary = _compute_colocation_summary_from_op(op)
device_summary = _compute_device_assignment_summary_from_op(op)
combined_summary = "\n".join([colocation_summary, device_summary])
# Optional traceback info.
try:
traceback = op.traceback
except AttributeError:
# Some ops synthesized on as part of function or control flow definition
# do not have tracebacks.
filename = "<unknown>"
lineno = 0
defined_at = " (defined at <unknown>)"
else:
frame = _get_defining_frame(traceback)
filename = frame.filename
if filename.startswith(strip_file_prefix):
filename = filename[len(strip_file_prefix):]
lineno = frame.lineno
defined_at = " (defined at %s:%d)" % (filename, lineno)
field_dict = {
"file": filename,
"line": lineno,
"defined_at": defined_at,
"colocations": colocation_summary,
"devices": device_summary,
"devs_and_colocs": combined_summary,
"defined_at": defined_at,
"file": filename,
"line": lineno,
}
return field_dict
@ -448,14 +464,14 @@ def _build_error_message(op, input_ops, common_prefix):
The formatted error message for the given op. The error message also
includes the information about the input sources for the given op.
"""
field_dict = compute_field_dict(op, common_prefix)
field_dict = _compute_field_dict(op, common_prefix)
msg = "node %s%s " % (op.name, field_dict["defined_at"])
input_debug_info = []
# This stores the line numbers that we have already printed.
done = set()
done.add(field_dict["defined_at"])
for op_inp in input_ops:
field_dict_inp = compute_field_dict(op_inp, common_prefix)
field_dict_inp = _compute_field_dict(op_inp, common_prefix)
if field_dict_inp["defined_at"] not in done:
input_debug_info.append(
" %s%s" % (op_inp.name, field_dict_inp["defined_at"]))
@ -507,7 +523,7 @@ def interpolate(error_message, graph):
if source_msg:
end_msg["source_nodes"].append(source_msg)
elif tag.type == "colocation_node":
field_dict = compute_field_dict(ops[0], common_prefix)
field_dict = _compute_field_dict(ops[0], common_prefix)
msg = "node %s%s placed on device %s " % (
ops[0].name, field_dict["defined_at"], field_dict["devices"])
end_msg["colocations"].append(field_dict["devs_and_colocs"])

View File

@ -155,9 +155,13 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
global_op = constant_op.constant(0, name="Global").op
op1 = constant_op.constant(1, name="One").op
op2 = constant_op.constant(2, name="Two").op
non_traceback_op = constant_op.constant(3, name="NonTraceback").op
# Ensure op without traceback does not fail
del non_traceback_op._traceback
# pyformat: enable
export_ops = [("", global_op), ("func1", op1), ("func2", op2)]
export_ops = [("", global_op), ("func1", op1), ("func2", op2),
("func2", non_traceback_op)]
graph_debug_info = error_interpolation.create_graph_debug_info_def(
export_ops)
this_file_index = -1
@ -201,7 +205,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
num_user_frames=3,
user_filename=user_filename,
num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
# Expected frame is 6th from the end because there are 5 inner frames witih
# TF filenames.
expected_frame = len(local_op._traceback) - 6
@ -217,7 +221,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
num_user_frames=0,
user_filename="user_file.py",
num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
self.assertEqual(0, idx)
def testNothingToDo(self):
@ -264,6 +268,9 @@ class InputNodesTest(test.TestCase):
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
three = math_ops.add(one, two, name="Three")
non_traceback_op = constant_op.constant(3, name="NonTraceback")
# Ensure op without traceback does not fail
del non_traceback_op.op._traceback
self.graph = three.graph
def testNoInputs(self):