diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index 0c6feb64008..77a2d13784a 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -223,7 +223,7 @@ def _is_framework_filename(filename): return False -def _find_index_of_defining_frame_for_op(op): +def _find_index_of_defining_frame(traceback): """Return index in op.traceback with first 'useful' frame. This method reads through the stack stored in op.traceback looking for the @@ -232,18 +232,16 @@ def _find_index_of_defining_frame_for_op(op): pattern matching the filename). Args: - op: the Operation object for which we would like to find the defining - location. + traceback: A list of traceback frames (as from Operation.traceback). Returns: Integer index into op.traceback where the first non-TF file was found (innermost to outermost), or 0 (for the outermost stack frame) if all files came from TensorFlow. """ - # Index 0 of tf_traceback is the outermost frame. - tf_traceback = op.traceback - size = len(tf_traceback) - filenames = [frame.filename for frame in tf_traceback] + # Index 0 of traceback is the outermost frame. + size = len(traceback) + filenames = [frame.filename for frame in traceback] # We process the filenames from the innermost frame to outermost. for idx, filename in enumerate(reversed(filenames)): is_framework = _is_framework_filename(filename) @@ -253,13 +251,13 @@ def _find_index_of_defining_frame_for_op(op): return 0 -def _get_defining_frame_from_op(op): +def _get_defining_frame(traceback): """Find and return stack frame where op was defined.""" - frame_index = _find_index_of_defining_frame_for_op(op) - return op.traceback[frame_index] + frame_index = _find_index_of_defining_frame(traceback) + return traceback[frame_index] -def _compute_useful_frames(op, num): +def _compute_useful_frames(traceback, num): """Return a list of frames, which form a 'useful' stack. Starting from the defining frame to the outermost one, this method computes @@ -267,20 +265,20 @@ def _compute_useful_frames(op, num): frames. Args: - op: op.Operation object having a _traceback member. + traceback: A list of traceback frames (as from Operation.traceback). num: total number of frames to return. Returns: A list of frames. """ - defining_frame_index = _find_index_of_defining_frame_for_op(op) + defining_frame_index = _find_index_of_defining_frame(traceback) # The stack trace is collected from two lines before the defining frame in the # model file to the outermost with `num` frames at most. These two extra lines # are included from the TensorFlow library to give the context which node is # defined. - innermost_excluded = min(defining_frame_index + 2 + 1, len(op.traceback)) + innermost_excluded = min(defining_frame_index + 2 + 1, len(traceback)) outermost_included = max(innermost_excluded - num, 0) - return op.traceback[outermost_included:innermost_excluded] + return traceback[outermost_included:innermost_excluded] def create_graph_debug_info_def(func_named_operations): @@ -305,9 +303,16 @@ def create_graph_debug_info_def(func_named_operations): all_file_names = set() node_to_trace = {} for func_name, op in func_named_operations: + try: + op_traceback = op.traceback + except AttributeError: + # Some ops synthesized on as part of function or control flow definition + # do not have tracebacks. + continue + # Gets the stack trace of the operation and then the file location. node_name = op.name + "@" + func_name - node_to_trace[node_name] = _compute_useful_frames(op, 10) + node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10) for frame in node_to_trace[node_name]: all_file_names.add(frame.filename) @@ -332,7 +337,7 @@ def create_graph_debug_info_def(func_named_operations): return graph_debug_info_def -def compute_field_dict(op, strip_file_prefix=""): +def _compute_field_dict(op, strip_file_prefix=""): """Return a dictionary mapping interpolation tokens to values. Args: @@ -364,23 +369,34 @@ def compute_field_dict(op, strip_file_prefix=""): with tf.device(some_func): ''' } """ - frame = _get_defining_frame_from_op(op) - filename = frame.filename - if filename.startswith(strip_file_prefix): - filename = filename[len(strip_file_prefix):] - lineno = frame.lineno - defined_at = " (defined at %s:%d)" % (filename, lineno) colocation_summary = _compute_colocation_summary_from_op(op) device_summary = _compute_device_assignment_summary_from_op(op) combined_summary = "\n".join([colocation_summary, device_summary]) + # Optional traceback info. + try: + traceback = op.traceback + except AttributeError: + # Some ops synthesized on as part of function or control flow definition + # do not have tracebacks. + filename = "" + lineno = 0 + defined_at = " (defined at )" + else: + frame = _get_defining_frame(traceback) + filename = frame.filename + if filename.startswith(strip_file_prefix): + filename = filename[len(strip_file_prefix):] + lineno = frame.lineno + defined_at = " (defined at %s:%d)" % (filename, lineno) + field_dict = { - "file": filename, - "line": lineno, - "defined_at": defined_at, "colocations": colocation_summary, "devices": device_summary, "devs_and_colocs": combined_summary, + "defined_at": defined_at, + "file": filename, + "line": lineno, } return field_dict @@ -448,14 +464,14 @@ def _build_error_message(op, input_ops, common_prefix): The formatted error message for the given op. The error message also includes the information about the input sources for the given op. """ - field_dict = compute_field_dict(op, common_prefix) + field_dict = _compute_field_dict(op, common_prefix) msg = "node %s%s " % (op.name, field_dict["defined_at"]) input_debug_info = [] # This stores the line numbers that we have already printed. done = set() done.add(field_dict["defined_at"]) for op_inp in input_ops: - field_dict_inp = compute_field_dict(op_inp, common_prefix) + field_dict_inp = _compute_field_dict(op_inp, common_prefix) if field_dict_inp["defined_at"] not in done: input_debug_info.append( " %s%s" % (op_inp.name, field_dict_inp["defined_at"])) @@ -507,7 +523,7 @@ def interpolate(error_message, graph): if source_msg: end_msg["source_nodes"].append(source_msg) elif tag.type == "colocation_node": - field_dict = compute_field_dict(ops[0], common_prefix) + field_dict = _compute_field_dict(ops[0], common_prefix) msg = "node %s%s placed on device %s " % ( ops[0].name, field_dict["defined_at"], field_dict["devices"]) end_msg["colocations"].append(field_dict["devs_and_colocs"]) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index 3cf5a9288b9..4e6027373cb 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -155,9 +155,13 @@ class CreateGraphDebugInfoDefTest(test.TestCase): global_op = constant_op.constant(0, name="Global").op op1 = constant_op.constant(1, name="One").op op2 = constant_op.constant(2, name="Two").op + non_traceback_op = constant_op.constant(3, name="NonTraceback").op + # Ensure op without traceback does not fail + del non_traceback_op._traceback # pyformat: enable - export_ops = [("", global_op), ("func1", op1), ("func2", op2)] + export_ops = [("", global_op), ("func1", op1), ("func2", op2), + ("func2", non_traceback_op)] graph_debug_info = error_interpolation.create_graph_debug_info_def( export_ops) this_file_index = -1 @@ -201,7 +205,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): num_user_frames=3, user_filename=user_filename, num_inner_tf_frames=5) - idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) # Expected frame is 6th from the end because there are 5 inner frames witih # TF filenames. expected_frame = len(local_op._traceback) - 6 @@ -217,7 +221,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): num_user_frames=0, user_filename="user_file.py", num_inner_tf_frames=7) - idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) self.assertEqual(0, idx) def testNothingToDo(self): @@ -264,6 +268,9 @@ class InputNodesTest(test.TestCase): one = constant_op.constant(1, name="One") two = constant_op.constant(2, name="Two") three = math_ops.add(one, two, name="Three") + non_traceback_op = constant_op.constant(3, name="NonTraceback") + # Ensure op without traceback does not fail + del non_traceback_op.op._traceback self.graph = three.graph def testNoInputs(self):