Make error interpolation and debug info tolerant of ops without a traceback.
* This was encountered when attempting to save debug info for a SavedModel with a functional while loop. The Placeholder op of the while loop condition lacked a _traceback attribute, causing save to fail. * I think this is expected behavior (no traceback on synthetic ops). Even if not, I believe that having the error_interpolator be conservatively correct in the face of a missing traceback is proper. * Cleans up some protected functions to take a traceback instead of an op. * Makes compute_field_dict protected since it is not used outside of this module. PiperOrigin-RevId: 280244393 Change-Id: I3af66387c6d3af1a779c84775ec4704e2b354859
This commit is contained in:
parent
af1cb5fe10
commit
2fb9116740
@ -223,7 +223,7 @@ def _is_framework_filename(filename):
|
||||
return False
|
||||
|
||||
|
||||
def _find_index_of_defining_frame_for_op(op):
|
||||
def _find_index_of_defining_frame(traceback):
|
||||
"""Return index in op.traceback with first 'useful' frame.
|
||||
|
||||
This method reads through the stack stored in op.traceback looking for the
|
||||
@ -232,18 +232,16 @@ def _find_index_of_defining_frame_for_op(op):
|
||||
pattern matching the filename).
|
||||
|
||||
Args:
|
||||
op: the Operation object for which we would like to find the defining
|
||||
location.
|
||||
traceback: A list of traceback frames (as from Operation.traceback).
|
||||
|
||||
Returns:
|
||||
Integer index into op.traceback where the first non-TF file was found
|
||||
(innermost to outermost), or 0 (for the outermost stack frame) if all files
|
||||
came from TensorFlow.
|
||||
"""
|
||||
# Index 0 of tf_traceback is the outermost frame.
|
||||
tf_traceback = op.traceback
|
||||
size = len(tf_traceback)
|
||||
filenames = [frame.filename for frame in tf_traceback]
|
||||
# Index 0 of traceback is the outermost frame.
|
||||
size = len(traceback)
|
||||
filenames = [frame.filename for frame in traceback]
|
||||
# We process the filenames from the innermost frame to outermost.
|
||||
for idx, filename in enumerate(reversed(filenames)):
|
||||
is_framework = _is_framework_filename(filename)
|
||||
@ -253,13 +251,13 @@ def _find_index_of_defining_frame_for_op(op):
|
||||
return 0
|
||||
|
||||
|
||||
def _get_defining_frame_from_op(op):
|
||||
def _get_defining_frame(traceback):
|
||||
"""Find and return stack frame where op was defined."""
|
||||
frame_index = _find_index_of_defining_frame_for_op(op)
|
||||
return op.traceback[frame_index]
|
||||
frame_index = _find_index_of_defining_frame(traceback)
|
||||
return traceback[frame_index]
|
||||
|
||||
|
||||
def _compute_useful_frames(op, num):
|
||||
def _compute_useful_frames(traceback, num):
|
||||
"""Return a list of frames, which form a 'useful' stack.
|
||||
|
||||
Starting from the defining frame to the outermost one, this method computes
|
||||
@ -267,20 +265,20 @@ def _compute_useful_frames(op, num):
|
||||
frames.
|
||||
|
||||
Args:
|
||||
op: op.Operation object having a _traceback member.
|
||||
traceback: A list of traceback frames (as from Operation.traceback).
|
||||
num: total number of frames to return.
|
||||
|
||||
Returns:
|
||||
A list of frames.
|
||||
"""
|
||||
defining_frame_index = _find_index_of_defining_frame_for_op(op)
|
||||
defining_frame_index = _find_index_of_defining_frame(traceback)
|
||||
# The stack trace is collected from two lines before the defining frame in the
|
||||
# model file to the outermost with `num` frames at most. These two extra lines
|
||||
# are included from the TensorFlow library to give the context which node is
|
||||
# defined.
|
||||
innermost_excluded = min(defining_frame_index + 2 + 1, len(op.traceback))
|
||||
innermost_excluded = min(defining_frame_index + 2 + 1, len(traceback))
|
||||
outermost_included = max(innermost_excluded - num, 0)
|
||||
return op.traceback[outermost_included:innermost_excluded]
|
||||
return traceback[outermost_included:innermost_excluded]
|
||||
|
||||
|
||||
def create_graph_debug_info_def(func_named_operations):
|
||||
@ -305,9 +303,16 @@ def create_graph_debug_info_def(func_named_operations):
|
||||
all_file_names = set()
|
||||
node_to_trace = {}
|
||||
for func_name, op in func_named_operations:
|
||||
try:
|
||||
op_traceback = op.traceback
|
||||
except AttributeError:
|
||||
# Some ops synthesized on as part of function or control flow definition
|
||||
# do not have tracebacks.
|
||||
continue
|
||||
|
||||
# Gets the stack trace of the operation and then the file location.
|
||||
node_name = op.name + "@" + func_name
|
||||
node_to_trace[node_name] = _compute_useful_frames(op, 10)
|
||||
node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10)
|
||||
for frame in node_to_trace[node_name]:
|
||||
all_file_names.add(frame.filename)
|
||||
|
||||
@ -332,7 +337,7 @@ def create_graph_debug_info_def(func_named_operations):
|
||||
return graph_debug_info_def
|
||||
|
||||
|
||||
def compute_field_dict(op, strip_file_prefix=""):
|
||||
def _compute_field_dict(op, strip_file_prefix=""):
|
||||
"""Return a dictionary mapping interpolation tokens to values.
|
||||
|
||||
Args:
|
||||
@ -364,23 +369,34 @@ def compute_field_dict(op, strip_file_prefix=""):
|
||||
with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
|
||||
}
|
||||
"""
|
||||
frame = _get_defining_frame_from_op(op)
|
||||
filename = frame.filename
|
||||
if filename.startswith(strip_file_prefix):
|
||||
filename = filename[len(strip_file_prefix):]
|
||||
lineno = frame.lineno
|
||||
defined_at = " (defined at %s:%d)" % (filename, lineno)
|
||||
colocation_summary = _compute_colocation_summary_from_op(op)
|
||||
device_summary = _compute_device_assignment_summary_from_op(op)
|
||||
combined_summary = "\n".join([colocation_summary, device_summary])
|
||||
|
||||
# Optional traceback info.
|
||||
try:
|
||||
traceback = op.traceback
|
||||
except AttributeError:
|
||||
# Some ops synthesized on as part of function or control flow definition
|
||||
# do not have tracebacks.
|
||||
filename = "<unknown>"
|
||||
lineno = 0
|
||||
defined_at = " (defined at <unknown>)"
|
||||
else:
|
||||
frame = _get_defining_frame(traceback)
|
||||
filename = frame.filename
|
||||
if filename.startswith(strip_file_prefix):
|
||||
filename = filename[len(strip_file_prefix):]
|
||||
lineno = frame.lineno
|
||||
defined_at = " (defined at %s:%d)" % (filename, lineno)
|
||||
|
||||
field_dict = {
|
||||
"file": filename,
|
||||
"line": lineno,
|
||||
"defined_at": defined_at,
|
||||
"colocations": colocation_summary,
|
||||
"devices": device_summary,
|
||||
"devs_and_colocs": combined_summary,
|
||||
"defined_at": defined_at,
|
||||
"file": filename,
|
||||
"line": lineno,
|
||||
}
|
||||
return field_dict
|
||||
|
||||
@ -448,14 +464,14 @@ def _build_error_message(op, input_ops, common_prefix):
|
||||
The formatted error message for the given op. The error message also
|
||||
includes the information about the input sources for the given op.
|
||||
"""
|
||||
field_dict = compute_field_dict(op, common_prefix)
|
||||
field_dict = _compute_field_dict(op, common_prefix)
|
||||
msg = "node %s%s " % (op.name, field_dict["defined_at"])
|
||||
input_debug_info = []
|
||||
# This stores the line numbers that we have already printed.
|
||||
done = set()
|
||||
done.add(field_dict["defined_at"])
|
||||
for op_inp in input_ops:
|
||||
field_dict_inp = compute_field_dict(op_inp, common_prefix)
|
||||
field_dict_inp = _compute_field_dict(op_inp, common_prefix)
|
||||
if field_dict_inp["defined_at"] not in done:
|
||||
input_debug_info.append(
|
||||
" %s%s" % (op_inp.name, field_dict_inp["defined_at"]))
|
||||
@ -507,7 +523,7 @@ def interpolate(error_message, graph):
|
||||
if source_msg:
|
||||
end_msg["source_nodes"].append(source_msg)
|
||||
elif tag.type == "colocation_node":
|
||||
field_dict = compute_field_dict(ops[0], common_prefix)
|
||||
field_dict = _compute_field_dict(ops[0], common_prefix)
|
||||
msg = "node %s%s placed on device %s " % (
|
||||
ops[0].name, field_dict["defined_at"], field_dict["devices"])
|
||||
end_msg["colocations"].append(field_dict["devs_and_colocs"])
|
||||
|
@ -155,9 +155,13 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
|
||||
global_op = constant_op.constant(0, name="Global").op
|
||||
op1 = constant_op.constant(1, name="One").op
|
||||
op2 = constant_op.constant(2, name="Two").op
|
||||
non_traceback_op = constant_op.constant(3, name="NonTraceback").op
|
||||
# Ensure op without traceback does not fail
|
||||
del non_traceback_op._traceback
|
||||
# pyformat: enable
|
||||
|
||||
export_ops = [("", global_op), ("func1", op1), ("func2", op2)]
|
||||
export_ops = [("", global_op), ("func1", op1), ("func2", op2),
|
||||
("func2", non_traceback_op)]
|
||||
graph_debug_info = error_interpolation.create_graph_debug_info_def(
|
||||
export_ops)
|
||||
this_file_index = -1
|
||||
@ -201,7 +205,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
num_user_frames=3,
|
||||
user_filename=user_filename,
|
||||
num_inner_tf_frames=5)
|
||||
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
|
||||
# Expected frame is 6th from the end because there are 5 inner frames witih
|
||||
# TF filenames.
|
||||
expected_frame = len(local_op._traceback) - 6
|
||||
@ -217,7 +221,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
num_user_frames=0,
|
||||
user_filename="user_file.py",
|
||||
num_inner_tf_frames=7)
|
||||
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
|
||||
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
|
||||
self.assertEqual(0, idx)
|
||||
|
||||
def testNothingToDo(self):
|
||||
@ -264,6 +268,9 @@ class InputNodesTest(test.TestCase):
|
||||
one = constant_op.constant(1, name="One")
|
||||
two = constant_op.constant(2, name="Two")
|
||||
three = math_ops.add(one, two, name="Three")
|
||||
non_traceback_op = constant_op.constant(3, name="NonTraceback")
|
||||
# Ensure op without traceback does not fail
|
||||
del non_traceback_op.op._traceback
|
||||
self.graph = three.graph
|
||||
|
||||
def testNoInputs(self):
|
||||
|
Loading…
Reference in New Issue
Block a user