* This was encountered when attempting to save debug info for a SavedModel
  with a functional while loop. The Placeholder op of the while loop
  condition lacked a _traceback attribute, causing the save to fail.
* I think this is expected behavior (no traceback on synthetic ops). Even if
  not, I believe that having the error_interpolator be conservatively correct
  in the face of a missing traceback is proper.
* Cleans up some protected functions to take a traceback instead of an op.
* Makes compute_field_dict protected since it is not used outside of this
  module.

PiperOrigin-RevId: 280244393
Change-Id: I3af66387c6d3af1a779c84775ec4704e2b354859
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Function for interpolating formatted errors from the TensorFlow runtime.
|
|
|
|
Exposes the function `interpolate` to interpolate messages with tags of the form
|
|
{{type name}}.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import itertools
|
|
import os
|
|
import re
|
|
|
|
import six
|
|
|
|
from tensorflow.core.protobuf import graph_debug_info_pb2
|
|
|
|
_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
|
|
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
|
|
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
|
|
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
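# Note: the quadruple braces in _TAG_REGEX are str.format escapes; after
# formatting, the pattern reads r"{{(<name>) (<name>)}}", which matches a
# literal tag such as "{{node Foo}}" and captures the type ("node") and the
# name ("Foo") as groups.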

_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])


# Remove the last three path components from this module's file (i.e.
# python/framework/error_interpolation.py) so that we have an absolute path
# prefix to the root of the installation.
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
    os.path.dirname(os.path.dirname(__file__)))

# Sub-directories under the common prefix that are considered part of the
# framework.
_FRAMEWORK_PATH_PREFIXES = [
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
]

# Patterns of filenames that should be considered internal to the
# TensorFlow framework.
_FRAMEWORK_FILENAME_PATTERNS = [
    re.compile(r"<embedded"),
]

# Patterns of filenames that should be considered external to TensorFlow
# regardless of framework prefix match.
_EXTERNAL_FILENAME_PATTERNS = [
    # Explicitly treat test frames as not part of the framework.
    re.compile(r"_test\.py$"),
]


def parse_message(message):
  """Parses the message.

  Splits the message into separators and tags. Tags are named tuples
  representing the string {{type name}} and they are separated by
  separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
  two tags and three separators. The separators are the numeric characters.

  Args:
    message: String to parse.

  Returns:
    (list of separator strings, list of _ParseTags).

    For example, if message is "123{{node Foo}}456" then this function
    returns (["123", "456"], [_ParseTag("node", "Foo")]).
  """
  seps = []
  tags = []
  pos = 0
  while pos < len(message):
    match = re.match(_INTERPOLATION_PATTERN, message[pos:])
    if match:
      seps.append(match.group(1))
      tags.append(_ParseTag(match.group(3), match.group(4)))
      pos += match.end()
    else:
      break
  seps.append(message[pos:])
  return seps, tags
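

# Illustrative example (not executed at import time):
#   seps, tags = parse_message("123{{node Foo}}456{{node Bar}}789")
#   # seps == ["123", "456", "789"]
#   # tags == [_ParseTag("node", "Foo"), _ParseTag("node", "Bar")]
# Note that seps always has one more element than tags; interpolate() below
# relies on this when re-interleaving them.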


def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Return a summary of an op's device function stack.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list.
    prefix: An optional string prefix used before each line of the multi-line
      string returned by this function.

  Returns:
    A multi-line string similar to:
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    The first line will have no padding to its left by default. Subsequent
    lines will have two spaces of left-padding. Use the prefix argument
    to increase indentation.
  """
  if not device_assignment_list:
    message = "No device assignments were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append(
      "%sDevice assignments active during op '%s' creation:" % (prefix, name))

  for traceable_obj in device_assignment_list:
    location_summary = "<{file}:{line}>".format(
        file=traceable_obj.filename, line=traceable_obj.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "dev_name": traceable_obj.obj,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_device_assignment_summary_from_op(op, prefix=""):
  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op.name, op._device_assignments,
                                           prefix)
  # pylint: enable=protected-access


def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Return a summary of an op's colocation stack.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict.
    prefix: An optional string prefix used before each line of the multi-line
      string returned by this function.

  Returns:
    A multi-line string similar to:
        Node-device colocations active during op creation:
          with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
          with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
    The first line will have no padding to its left by default. Subsequent
    lines will have two spaces of left-padding. Use the prefix argument
    to increase indentation.
  """
  if not colocation_dict:
    message = "No node-device colocations were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append("%sNode-device colocations active during op '%s' creation:" %
                  (prefix, name))

  for coloc_name, location in colocation_dict.items():
    location_summary = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "name": coloc_name,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_colocation_summary_from_op(op, prefix=""):
  """Fetch colocation file, line, and nesting and return a summary string."""
  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
                                               prefix)
  # pylint: enable=protected-access


def _is_framework_filename(filename):
  """Returns whether a filename should be considered a part of the framework.

  A file is part of the framework if it does not match a pattern in
  _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
  _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES
  prefix.

  Args:
    filename: A filename string.

  Returns:
    Whether the filename should be considered to be internal to the
    TensorFlow framework for the purposes of reporting errors.
  """
  for pattern in _EXTERNAL_FILENAME_PATTERNS:
    if pattern.search(filename):
      return False
  for pattern in _FRAMEWORK_FILENAME_PATTERNS:
    if pattern.search(filename):
      return True
  for prefix in _FRAMEWORK_PATH_PREFIXES:
    if filename.startswith(prefix):
      return True
  return False
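

# Classification examples (hypothetical paths, assuming the install root is
# "/site-packages/tensorflow"):
#   "/site-packages/tensorflow/python/ops/math_ops.py" -> framework (prefix)
#   "<embedded module>"                                -> framework (pattern)
#   "/site-packages/tensorflow/python/ops/nn_test.py"  -> external (_test.py
#                                                          checked first)
#   "/home/user/model.py"                              -> external (no match)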


def _find_index_of_defining_frame(traceback):
  """Return index in traceback with first 'useful' frame.

  This method reads through the given traceback looking for the innermost
  frame which (hopefully) belongs to the caller. It accomplishes this by
  rejecting frames deemed to be part of the TensorFlow framework (by
  pattern matching the filename).

  Args:
    traceback: A list of traceback frames (as from Operation.traceback).

  Returns:
    Integer index into traceback where the first non-TF file was found
    (innermost to outermost), or 0 (for the outermost stack frame) if all
    files came from TensorFlow.
  """
  # Index 0 of traceback is the outermost frame.
  size = len(traceback)
  filenames = [frame.filename for frame in traceback]
  # We process the filenames from the innermost frame to outermost.
  for idx, filename in enumerate(reversed(filenames)):
    is_framework = _is_framework_filename(filename)
    if not is_framework:
      # Consider this to be the defining frame.
      return size - idx - 1
  return 0
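

# Worked example (hypothetical frames, innermost last):
#   traceback = [<main.py:10>, <tf/python/framework/ops.py:500>]
# Scanning innermost-to-outermost, ops.py is a framework file and is
# rejected; main.py is not, so the defining frame index is 0.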


def _get_defining_frame(traceback):
  """Find and return stack frame where op was defined."""
  frame_index = _find_index_of_defining_frame(traceback)
  return traceback[frame_index]


def _compute_useful_frames(traceback, num):
  """Return a list of frames, which form a 'useful' stack.

  Starting from the defining frame to the outermost one, this method computes
  the contiguous portion of the 'useful' stack trace and returns the selected
  frames.

  Args:
    traceback: A list of traceback frames (as from Operation.traceback).
    num: total number of frames to return.

  Returns:
    A list of frames.
  """
  defining_frame_index = _find_index_of_defining_frame(traceback)
  # The stack trace is collected from the outermost frame down to two frames
  # past the defining frame, keeping at most `num` frames. The two extra
  # inner frames come from the TensorFlow library and give context about
  # which node is being defined.
  innermost_excluded = min(defining_frame_index + 2 + 1, len(traceback))
  outermost_included = max(innermost_excluded - num, 0)
  return traceback[outermost_included:innermost_excluded]
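

# Slice arithmetic example (hypothetical numbers): with len(traceback) == 20,
# defining_frame_index == 12 and num == 10, the slice keeps frames [5:15):
# the defining frame at index 12, two inner TF frames (13, 14), and seven
# outer caller frames (5 through 11).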


def create_graph_debug_info_def(func_named_operations):
  """Constructs and returns a `GraphDebugInfo` protocol buffer.

  Args:
    func_named_operations: An iterable of (func_name, ops.Operation) tuples
      where the Operation instances have a _traceback member. The func_name
      should be the empty string for operations in the top-level Graph.

  Returns:
    GraphDebugInfo protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # Creates an empty GraphDebugInfo proto.
  graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()

  # Gets the file names and line numbers for the exported node names. Also
  # collects the unique file names.
  all_file_names = set()
  node_to_trace = {}
  for func_name, op in func_named_operations:
    try:
      op_traceback = op.traceback
    except AttributeError:
      # Some ops synthesized as part of function or control flow definition
      # do not have tracebacks.
      continue

    # Gets the stack trace of the operation and then the file location.
    node_name = op.name + "@" + func_name
    node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10)
    for frame in node_to_trace[node_name]:
      all_file_names.add(frame.filename)

  # Sets the `files` field in the GraphDebugInfo proto.
  graph_debug_info_def.files.extend(all_file_names)

  # Builds a mapping from file name to its index in the `files` field, so
  # that each node stores only file indexes in the GraphDebugInfo.
  file_to_index = dict(
      [(y, x) for x, y in enumerate(graph_debug_info_def.files)])

  # Creates the FileLineCol proto for each node and sets the value in the
  # GraphDebugInfo proto. We only store the file name index for each node to
  # save storage space.
  for node_name, frames in node_to_trace.items():
    trace_def = graph_debug_info_def.traces[node_name]
    for frame in reversed(frames):
      trace_def.file_line_cols.add(
          file_index=file_to_index[frame.filename],
          line=frame.lineno)

  return graph_debug_info_def
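

# Usage sketch (hypothetical; assumes `graph` is a tf.Graph whose ops carry
# tracebacks):
#   debug_info = create_graph_debug_info_def(
#       ("", op) for op in graph.get_operations())
#   # debug_info.traces maps "<op_name>@" to the op's stack trace, with file
#   # names deduplicated in debug_info.files. Ops without a traceback (e.g.
#   # synthetic control flow placeholders) are silently skipped.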


def _compute_field_dict(op, strip_file_prefix=""):
  """Return a dictionary mapping interpolation tokens to values.

  Args:
    op: ops.Operation object having a _traceback member.
    strip_file_prefix: The common path in the stacktrace. We remove the prefix
      from the file names.

  Returns:
    A dictionary mapping string tokens to string values. The keys are shown
    below along with example values.
    {
      "file": "tool_utils.py",
      "line": "124",
      "defined_at": " (defined at tool_utils.py:124)",
      "colocations":
          '''Node-device colocations active during op creation:
            with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
            with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>''',
      "devices":
          '''Device assignments active during op 'foo' creation:
            with tf.device(/cpu:0): <test_1.py:27>
            with tf.device(some_func<foo.py, 123>): <test_2.py:38>''',
      "devs_and_colocs": A concatenation of colocations and devices, e.g.
          '''Node-device colocations active during op creation:
            with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
            with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
          Device assignments active during op 'foo' creation:
            with tf.device(/cpu:0): <test_1.py:27>
            with tf.device(some_func<foo.py, 123>): <test_2.py:38>''',
    }
  """
  colocation_summary = _compute_colocation_summary_from_op(op)
  device_summary = _compute_device_assignment_summary_from_op(op)
  combined_summary = "\n".join([colocation_summary, device_summary])

  # Optional traceback info.
  try:
    traceback = op.traceback
  except AttributeError:
    # Some ops synthesized as part of function or control flow definition
    # do not have tracebacks.
    filename = "<unknown>"
    lineno = 0
    defined_at = " (defined at <unknown>)"
  else:
    frame = _get_defining_frame(traceback)
    filename = frame.filename
    if filename.startswith(strip_file_prefix):
      filename = filename[len(strip_file_prefix):]
    lineno = frame.lineno
    defined_at = " (defined at %s:%d)" % (filename, lineno)

  field_dict = {
      "colocations": colocation_summary,
      "devices": device_summary,
      "devs_and_colocs": combined_summary,
      "defined_at": defined_at,
      "file": filename,
      "line": lineno,
  }
  return field_dict


def traceback_files_common_prefix(all_ops):
  """Determines the common prefix from the paths of the stacktrace of 'all_ops'.

  For example, if the paths are '/foo/bar/baz/' and '/foo/car', this would
  return '/foo'.

  Args:
    all_ops: All the input nodes in the form of a list of lists of ops.

  Returns:
    The common prefix.
  """
  files = set()
  for ops in all_ops:
    if ops is None:
      continue
    for op in ops:
      # TODO(slebedev): switch to .filename once 2.X support is dropped.
      for filename, _, _, _ in op.traceback:
        if "<embedded" not in filename:
          files.add(filename)
  return os.path.split(os.path.commonprefix(list(files)))[0]
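

# Illustrative example: with files {"/foo/bar/baz.py", "/foo/car.py"},
# os.path.commonprefix returns the character-wise prefix "/foo/", and
# os.path.split("/foo/")[0] trims the partial trailing component, giving
# "/foo".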


def _sources_for_node(node, graph):
  """Gets the input op nodes for 'node'.

  Args:
    node: The node.
    graph: The graph containing the node.

  Returns:
    The unique input nodes.
  """
  inputs = set()
  for name in node.node_def.input:
    if name.startswith("^"):
      name = name[1:]
    try:
      tensor = graph.get_tensor_by_name(name)
      op = tensor.op
    except (KeyError, ValueError):
      try:
        op = graph.get_operation_by_name(name)
      except KeyError:
        continue
    inputs.add(op)

  return list(inputs)
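

# Input names in a NodeDef can take several forms (illustrative):
#   "Foo:0" -- output 0 of op Foo (resolved via get_tensor_by_name);
#   "Foo"   -- the op Foo itself (resolved via get_operation_by_name);
#   "^Foo"  -- a control dependency on Foo; the leading "^" is stripped above.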


def _build_error_message(op, input_ops, common_prefix):
  """Returns the formatted error message for the given op.

  Args:
    op: The node.
    input_ops: The input nodes to the 'op' node.
    common_prefix: The prefix path common to the stacktrace of inputs.

  Returns:
    A tuple (msg, end_msg), where msg names the op and where it was defined,
    and end_msg lists the input source operations for the given op (or is
    empty if there are none to report).
  """
  field_dict = _compute_field_dict(op, common_prefix)
  msg = "node %s%s " % (op.name, field_dict["defined_at"])
  input_debug_info = []
  # This stores the line numbers that we have already printed.
  done = set()
  done.add(field_dict["defined_at"])
  for op_inp in input_ops:
    field_dict_inp = _compute_field_dict(op_inp, common_prefix)
    if field_dict_inp["defined_at"] not in done:
      input_debug_info.append(
          " %s%s" % (op_inp.name, field_dict_inp["defined_at"]))
      done.add(field_dict_inp["defined_at"])
  if input_debug_info:
    end_msg = ("\nInput Source operations connected to node %s:\n") % (op.name)
    end_msg += "\t\n".join(input_debug_info)
  else:
    end_msg = ""
  return msg, end_msg
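

# Example of the resulting strings (hypothetical values):
#   msg     == "node Foo (defined at model.py:12) "
#   end_msg == "\nInput Source operations connected to node Foo:\n"
#              " Bar (defined at model.py:7)"
# Inputs defined at the same location as the op (or as each other) are
# deduplicated via the `done` set.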


def interpolate(error_message, graph):
  """Interpolates an error message.

  The error message can contain tags of the form `{{type name}}` which will be
  replaced. For example: "{{node <name>}}" would get expanded to:
  "node <name> (defined at <path>)".

  Args:
    error_message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
      message.

  Returns:
    The string with tags of the form {{type name}} interpolated.
  """
  seps, tags = parse_message(error_message)
  subs = []
  end_msg = collections.defaultdict(list)
  tagged_ops = []

  for t in tags:
    try:
      op = graph.get_operation_by_name(t.name)
    except KeyError:
      op = None
    if op is None:
      tagged_ops.append(None)
    else:
      tagged_ops.append([op] + _sources_for_node(op, graph))

  common_prefix = traceback_files_common_prefix(tagged_ops)
  for tag, ops in zip(tags, tagged_ops):
    msg = "{{%s %s}}" % (tag.type, tag.name)
    if ops is not None:
      if tag.type == "node":
        msg, source_msg = _build_error_message(ops[0], ops[1:], common_prefix)
        if source_msg:
          end_msg["source_nodes"].append(source_msg)
      elif tag.type == "colocation_node":
        field_dict = _compute_field_dict(ops[0], common_prefix)
        msg = "node %s%s placed on device %s " % (
            ops[0].name, field_dict["defined_at"], field_dict["devices"])
        end_msg["colocations"].append(field_dict["devs_and_colocs"])
    if tag.type == "function_node":
      msg = ""
    subs.append(msg)

  if "source_nodes" in end_msg:
    subs.append("\n\nErrors may have originated from an input operation.")
    subs.append("\n".join(end_msg["source_nodes"]))
    end_msg.pop("source_nodes", None)
  for k, messages in end_msg.items():
    subs.append("Additional information about %s:" % k)
    subs.append("\n".join(messages))

  return "".join(
      itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
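

# End-to-end sketch (hypothetical; assumes `g` is a tf.Graph containing an op
# named "Foo"):
#   message = interpolate("Error at {{node Foo}}: bad shape", g)
#   # => "Error at node Foo (defined at model.py:12) : bad shape", followed
#   #    by a trailing section listing Foo's input source operations, if any.
# Unknown tag names are left verbatim as "{{type name}}" in the output.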