Use StackFrame of node during import

Copy the StackFrames over after copying the Graph (we still need to remove that step)

PiperOrigin-RevId: 352904629
Change-Id: I36dadc7171e26b00726dfc0863d719a4746f40c6
Author: Jacques Pienaar, 2021-01-20 16:51:21 -08:00 (committed by TensorFlower Gardener)
parent bc2269a05e
commit 03acd8ec0b
3 changed files with 83 additions and 30 deletions
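
A minimal usage sketch of the effect on debug locations (not part of the commit; it assumes the internal tensorflow.python.compiler.mlir.mlir wrapper that the updated mlir_test.py below exercises): with node stack frames wired into the importer, converting a concrete function with show_debug_info=True yields callsite locations that point back at the defining Python file rather than plain name locations.

import tensorflow as tf
from tensorflow.python.compiler.mlir import mlir  # internal wrapper used by mlir_test.py

@tf.function
def sqr(x):
  return x * x

concrete_function = sqr.get_concrete_function(tf.TensorSpec(None, tf.float32))
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
# With node stack frames feeding GetLocation, the dumped module carries
# callsite("<your_file>.py":<line>:1 ...) locations derived from the Python
# stack trace instead of bare name locations such as loc("sqr").
print(mlir_module)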


@@ -377,7 +377,7 @@ class ImporterBase {
// there are multiple "original_node_names", a FusedLoc is returned. If the
// node name couldn't be found in the input DebugInfo, a NameLoc is used as
// the location.
mlir::Location GetLocation(const NodeDef& node);
mlir::Location GetLocation(const Node& node);
// Appends the location string for the node to the error message and returns
// the combined error status.
@@ -598,6 +598,31 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
GetReversePostOrder(
*graph_, &ordered_nodes_,
[](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
return Status::OK();
}
Status CopyStackTraces(const Graph& from, Graph* to) {
// Copy over the stack traces.
// TODO(jpienaar): This really shouldn't be needed, copying the Graph above
// and then needing these traversals is unfortunate.
std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
for (Node* node : to->nodes()) {
if (const Node* old_node = node_map[node->name()]) {
if (const std::shared_ptr<AbstractStackTrace>& stack =
old_node->GetStackTrace()) {
DVLOG(2) << "Stack for " << node->name() << " "
<< old_node->GetStackTrace()->ToString(
AbstractStackTrace::TracePrintingOptions());
node->SetStackTrace(stack);
} else {
DVLOG(1) << "No stack for " << node->name() << " (" << node
<< ") in Graph " << &from;
}
} else {
DVLOG(1) << "No stack for " << node->name() << " (" << node
<< ") in Graph " << &from;
}
}
return Status::OK();
}
@@ -1385,6 +1410,7 @@ Status ImporterBase::ConvertFeedsToPlaceholders(
Status ImporterBase::PrepareConvert(const Graph& graph) {
TF_RETURN_IF_ERROR(RemoveBackedges(graph));
TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
auto node_name_map = graph_->BuildNodeNameIndex();
@@ -1579,7 +1605,8 @@ Status ImporterBase::ConvertFunctionArgAndRets(
return Status::OK();
}
mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
mlir::Location ImporterBase::GetLocation(const Node& node) {
DVLOG(1) << "Getting location for " << node.name() << " " << &node;
// TODO(b/142400497): What is the semantic contract for locations?
const auto& debug_info = debug_info_.traces();
@@ -1599,21 +1626,37 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
std::string name_for_name_loc =
function_name.empty() ? name.str() : (name + "@" + function_name).str();
auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
const auto location_it = debug_info.find(debug_info_key);
if (location_it == debug_info.end()) {
return mlir::NameLoc::get(name_loc_id, context_);
}
// Convert the stack trace to a chain of mlir::CallSiteLocs.
const auto& trace = location_it->second;
llvm::SmallVector<mlir::Location, 4> locations;
locations.reserve(trace.file_line_cols_size());
for (const auto& location : trace.file_line_cols()) {
const auto& file = debug_info_.files(location.file_index());
auto file_name = mlir::Identifier::get(file, context_);
auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
location.col(), context_);
locations.push_back(file_line_loc);
// Prefer stack traces if available, fall back to debug info if not, and
// finally to just the node name.
if (auto stack_trace = node.GetStackTrace()) {
DVLOG(1) << "Stack available for " << node.name();
absl::Span<const StackFrame> frames = stack_trace->ToFrames();
locations.reserve(frames.size());
for (const StackFrame& frame : llvm::reverse(frames)) {
auto file_name = mlir::Identifier::get(frame.file_name, context_);
// Use col 1 as there is no column info in StackTrace.
auto file_line_loc = mlir::FileLineColLoc::get(
file_name, frame.line_number, 1, context_);
locations.push_back(file_line_loc);
}
} else {
DVLOG(1) << "No stack trace for " << node.name();
const auto location_it = debug_info.find(debug_info_key);
if (location_it != debug_info.end()) {
DVLOG(1) << "Available serialized debug info for " << node.name();
// Convert the stack trace to a chain of mlir::CallSiteLocs.
const auto& trace = location_it->second;
locations.reserve(trace.file_line_cols_size());
for (const auto& location : trace.file_line_cols()) {
const auto& file = debug_info_.files(location.file_index());
auto file_name = mlir::Identifier::get(file, context_);
auto file_line_loc = mlir::FileLineColLoc::get(
file_name, location.line(), location.col(), context_);
locations.push_back(file_line_loc);
}
}
}
// If there are no locations in the stack trace, fall back to just a
@@ -1636,16 +1679,20 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
// Hence, we use node name as location to keep it unique.
// TODO(prakalps): In future the plan is to use tokens to pair source/sink
// nodes. Then NextIteration nodes would not need to be handled separately.
if (node_def.op() == "NextIteration")
return create_location(node_def.name(), function_name_for_debug_info_);
if (node.type_string() == "NextIteration")
return create_location(node.name(), function_name_for_debug_info_);
if (node.GetStackTrace())
return create_location(node.name(), function_name_for_debug_info_);
const auto& node_def = node.def();
auto original_nodes =
node_def.experimental_debug_info().original_node_names();
auto original_funcs =
node_def.experimental_debug_info().original_func_names();
if (original_nodes.empty()) {
return create_location(node_def.name(), function_name_for_debug_info_);
return create_location(node.name(), function_name_for_debug_info_);
} else {
// If the original nodes are defined, then we use them to get a list of
// call sites, and then fuse them to a single fused location, with the name
@@ -1661,14 +1708,14 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
}
// store the name of the node_def
node_locations.push_back(
create_location(node_def.name(), function_name_for_debug_info_));
create_location(node.name(), function_name_for_debug_info_));
return mlir::FusedLoc::get(node_locations, context_);
}
}
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
const Status& error_status) {
const mlir::Location location = GetLocation(node.def());
const mlir::Location location = GetLocation(node);
mlir::emitError(location);
return error_handler_.Combine(error_status);
}
@@ -1832,8 +1879,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
op_name = op_name + ".sink";
}
const auto& node_def = node.def();
mlir::OperationState result(GetLocation(node_def), op_name);
mlir::OperationState result(GetLocation(node), op_name);
for (int i = 0; i < node.num_outputs(); ++i) {
// The backedge has been removed, so we shouldn't count the corresponding
// output from the src node when converting to an operation.
@@ -1937,6 +1983,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
&result.attributes));
}
const auto& node_def = node.def();
result.attributes.push_back(builder_.getNamedAttr(
"device", builder_.getStringAttr(std::string(node_def.device()))));


@@ -55,7 +55,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
tensor_spec.TensorSpec(None, dtypes.float32))
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
self.assertRegex(mlir_module, r'loc\(')
self.assertRegex(mlir_module, r'callsite\(".*mlir_test.py":')
@test_util.run_v2_only
def testImportWithCall(self):


@@ -871,23 +871,29 @@ class DefFunctionTest(xla_test.XLATestCase):
f(constant_op.constant(1))
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
'support stack traces')
def testTensorArrayErrorMessage(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(jit_compile=True)
def f():
ta = tensor_array_ops.TensorArray(
# The old and new bridge differ in which op they flag in the error message:
# one points to the creation of the uninitialized tensor array, the other
# to the use of the uninitialized tensor array.
ta = tensor_array_ops.TensorArray( # EXPECTED_MESSAGE_NEW
dtype=dtypes.float32,
size=2,
dynamic_size=True,
element_shape=(None,))
return ta.concat() # EXPECTED_MESSAGE
return ta.concat() # EXPECTED_MESSAGE_OLD
with self.assertRaisesRegex(errors.InvalidArgumentError,
'EXPECTED_MESSAGE'):
f()
if test_util.is_mlir_bridge_enabled():
with self.assertRaisesRegex(errors.InternalError,
'EXPECTED_MESSAGE_NEW'):
f()
else:
with self.assertRaisesRegex(errors.InvalidArgumentError,
'EXPECTED_MESSAGE_OLD'):
f()
if __name__ == '__main__':