Use StackFrame of node during import
Copy the StackFrames over post copying the Graph (we need to remove that still) PiperOrigin-RevId: 352904629 Change-Id: I36dadc7171e26b00726dfc0863d719a4746f40c6
This commit is contained in:
parent
bc2269a05e
commit
03acd8ec0b
@ -377,7 +377,7 @@ class ImporterBase {
|
||||
// there are multiple "original_node_names", a FusedLoc is returned. If the
|
||||
// node name couldn't be found in the input DebugInfo, a NameLoc is used as
|
||||
// the location.
|
||||
mlir::Location GetLocation(const NodeDef& node);
|
||||
mlir::Location GetLocation(const Node& node);
|
||||
|
||||
// Appends the location string for the node to the error message and returns
|
||||
// the combined error status.
|
||||
@ -598,6 +598,31 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
|
||||
GetReversePostOrder(
|
||||
*graph_, &ordered_nodes_,
|
||||
[](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CopyStackTraces(const Graph& from, Graph* to) {
|
||||
// Copy over the stack traces.
|
||||
// TODO(jpienaar): This really shouldn't be needed, copying the Graph above
|
||||
// and then needing these traversals is unfortunate.
|
||||
std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
|
||||
for (Node* node : to->nodes()) {
|
||||
if (const Node* old_node = node_map[node->name()]) {
|
||||
if (const std::shared_ptr<AbstractStackTrace>& stack =
|
||||
old_node->GetStackTrace()) {
|
||||
DVLOG(2) << "Stack for " << node->name() << " "
|
||||
<< old_node->GetStackTrace()->ToString(
|
||||
AbstractStackTrace::TracePrintingOptions());
|
||||
node->SetStackTrace(stack);
|
||||
} else {
|
||||
DVLOG(1) << "No stack for " << node->name() << " (" << node
|
||||
<< ") in Graph " << &from;
|
||||
}
|
||||
} else {
|
||||
DVLOG(1) << "No stack for " << node->name() << " (" << node
|
||||
<< ") in Graph " << &from;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1385,6 +1410,7 @@ Status ImporterBase::ConvertFeedsToPlaceholders(
|
||||
|
||||
Status ImporterBase::PrepareConvert(const Graph& graph) {
|
||||
TF_RETURN_IF_ERROR(RemoveBackedges(graph));
|
||||
TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
|
||||
|
||||
auto node_name_map = graph_->BuildNodeNameIndex();
|
||||
|
||||
@ -1579,7 +1605,8 @@ Status ImporterBase::ConvertFunctionArgAndRets(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
mlir::Location ImporterBase::GetLocation(const Node& node) {
|
||||
DVLOG(1) << "Getting location for " << node.name() << " " << &node;
|
||||
// TODO(b/142400497): What is the semantic contract for locations?
|
||||
const auto& debug_info = debug_info_.traces();
|
||||
|
||||
@ -1599,21 +1626,37 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
std::string name_for_name_loc =
|
||||
function_name.empty() ? name.str() : (name + "@" + function_name).str();
|
||||
auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
|
||||
const auto location_it = debug_info.find(debug_info_key);
|
||||
if (location_it == debug_info.end()) {
|
||||
return mlir::NameLoc::get(name_loc_id, context_);
|
||||
}
|
||||
|
||||
// Convert the stack trace to a chain of mlir::CallSiteLocs.
|
||||
const auto& trace = location_it->second;
|
||||
llvm::SmallVector<mlir::Location, 4> locations;
|
||||
locations.reserve(trace.file_line_cols_size());
|
||||
for (const auto& location : trace.file_line_cols()) {
|
||||
const auto& file = debug_info_.files(location.file_index());
|
||||
auto file_name = mlir::Identifier::get(file, context_);
|
||||
auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
|
||||
location.col(), context_);
|
||||
locations.push_back(file_line_loc);
|
||||
// Prefer stack traces if available, fallback to debug info if not, and then
|
||||
// finally to just name.
|
||||
if (auto stack_trace = node.GetStackTrace()) {
|
||||
DVLOG(1) << "Stack available for " << node.name();
|
||||
absl::Span<const StackFrame> frames = stack_trace->ToFrames();
|
||||
locations.reserve(frames.size());
|
||||
for (const StackFrame& frame : llvm::reverse(frames)) {
|
||||
auto file_name = mlir::Identifier::get(frame.file_name, context_);
|
||||
// Use col 1 as there is no column info in StackTrace.
|
||||
auto file_line_loc = mlir::FileLineColLoc::get(
|
||||
file_name, frame.line_number, 1, context_);
|
||||
locations.push_back(file_line_loc);
|
||||
}
|
||||
} else {
|
||||
DVLOG(1) << "No stack trace for " << node.name();
|
||||
const auto location_it = debug_info.find(debug_info_key);
|
||||
if (location_it != debug_info.end()) {
|
||||
DVLOG(1) << "Available serialized debug info for " << node.name();
|
||||
// Convert the stack trace to a chain of mlir::CallSiteLocs.
|
||||
const auto& trace = location_it->second;
|
||||
locations.reserve(trace.file_line_cols_size());
|
||||
for (const auto& location : trace.file_line_cols()) {
|
||||
const auto& file = debug_info_.files(location.file_index());
|
||||
auto file_name = mlir::Identifier::get(file, context_);
|
||||
auto file_line_loc = mlir::FileLineColLoc::get(
|
||||
file_name, location.line(), location.col(), context_);
|
||||
locations.push_back(file_line_loc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no locations in the stack trace, fall back to just a
|
||||
@ -1636,16 +1679,20 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
// Hence, we use node name as location to keep it unique.
|
||||
// TODO(prakalps): In future the plan is to use tokens to pair source/sink
|
||||
// nodes. Then NextIteration nodes would not need to be handled separately.
|
||||
if (node_def.op() == "NextIteration")
|
||||
return create_location(node_def.name(), function_name_for_debug_info_);
|
||||
if (node.type_string() == "NextIteration")
|
||||
return create_location(node.name(), function_name_for_debug_info_);
|
||||
|
||||
if (node.GetStackTrace())
|
||||
return create_location(node.name(), function_name_for_debug_info_);
|
||||
|
||||
const auto& node_def = node.def();
|
||||
auto original_nodes =
|
||||
node_def.experimental_debug_info().original_node_names();
|
||||
auto original_funcs =
|
||||
node_def.experimental_debug_info().original_func_names();
|
||||
|
||||
if (original_nodes.empty()) {
|
||||
return create_location(node_def.name(), function_name_for_debug_info_);
|
||||
return create_location(node.name(), function_name_for_debug_info_);
|
||||
} else {
|
||||
// If the original nodes are defined, then we use them to get a list of
|
||||
// call sites, and then fuse them to a single fused location, with the name
|
||||
@ -1661,14 +1708,14 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
}
|
||||
// store the name of the node_def
|
||||
node_locations.push_back(
|
||||
create_location(node_def.name(), function_name_for_debug_info_));
|
||||
create_location(node.name(), function_name_for_debug_info_));
|
||||
return mlir::FusedLoc::get(node_locations, context_);
|
||||
}
|
||||
}
|
||||
|
||||
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
|
||||
const Status& error_status) {
|
||||
const mlir::Location location = GetLocation(node.def());
|
||||
const mlir::Location location = GetLocation(node);
|
||||
mlir::emitError(location);
|
||||
return error_handler_.Combine(error_status);
|
||||
}
|
||||
@ -1832,8 +1879,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
||||
op_name = op_name + ".sink";
|
||||
}
|
||||
|
||||
const auto& node_def = node.def();
|
||||
mlir::OperationState result(GetLocation(node_def), op_name);
|
||||
mlir::OperationState result(GetLocation(node), op_name);
|
||||
for (int i = 0; i < node.num_outputs(); ++i) {
|
||||
// The backedge has been removed, so we shouldn't count the corresponding
|
||||
// output from the src node when converting to an operation.
|
||||
@ -1937,6 +1983,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
||||
&result.attributes));
|
||||
}
|
||||
|
||||
const auto& node_def = node.def();
|
||||
result.attributes.push_back(builder_.getNamedAttr(
|
||||
"device", builder_.getStringAttr(std::string(node_def.device()))));
|
||||
|
||||
|
@ -55,7 +55,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
|
||||
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
|
||||
self.assertRegex(mlir_module, r'loc\(')
|
||||
self.assertRegex(mlir_module, r'callsite\(".*mlir_test.py":')
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testImportWithCall(self):
|
||||
|
@ -871,23 +871,29 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
f(constant_op.constant(1))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
def testTensorArrayErrorMessage(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(jit_compile=True)
|
||||
def f():
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
# The error message as old and new bridge differ in which op they flag.
|
||||
# The one points to the creation of the unitialized tensor array, the
|
||||
# other is the use of the unitialized tensor array.
|
||||
ta = tensor_array_ops.TensorArray( # EXPECTED_MESSAGE_NEW
|
||||
dtype=dtypes.float32,
|
||||
size=2,
|
||||
dynamic_size=True,
|
||||
element_shape=(None,))
|
||||
return ta.concat() # EXPECTED_MESSAGE
|
||||
return ta.concat() # EXPECTED_MESSAGE_OLD
|
||||
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'EXPECTED_MESSAGE'):
|
||||
f()
|
||||
if test_util.is_mlir_bridge_enabled():
|
||||
with self.assertRaisesRegex(errors.InternalError,
|
||||
'EXPECTED_MESSAGE_NEW'):
|
||||
f()
|
||||
else:
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'EXPECTED_MESSAGE_OLD'):
|
||||
f()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user