diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 6735449a2b8..912fd7f0482 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -377,7 +377,7 @@ class ImporterBase {
   // there are multiple "original_node_names", a FusedLoc is returned. If the
   // node name couldn't be found in the input DebugInfo, a NameLoc is used as
   // the location.
-  mlir::Location GetLocation(const NodeDef& node);
+  mlir::Location GetLocation(const Node& node);
 
   // Appends the location string for the node to the error message and returns
   // the combined error status.
@@ -598,6 +598,31 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
   GetReversePostOrder(
       *graph_, &ordered_nodes_,
       [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
+  return Status::OK();
+}
+
+Status CopyStackTraces(const Graph& from, Graph* to) {
+  // Copy over the stack traces.
+  // TODO(jpienaar): This really shouldn't be needed, copying the Graph above
+  // and then needing these traversals is unfortunate.
+  std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
+  for (Node* node : to->nodes()) {
+    if (const Node* old_node = node_map[node->name()]) {
+      if (const std::shared_ptr<AbstractStackTrace>& stack =
+              old_node->GetStackTrace()) {
+        DVLOG(2) << "Stack for " << node->name() << " "
+                 << old_node->GetStackTrace()->ToString(
+                        AbstractStackTrace::TracePrintingOptions());
+        node->SetStackTrace(stack);
+      } else {
+        DVLOG(1) << "No stack for " << node->name() << " (" << node
+                 << ") in Graph " << &from;
+      }
+    } else {
+      DVLOG(1) << "No stack for " << node->name() << " (" << node
+               << ") in Graph " << &from;
+    }
+  }
 
   return Status::OK();
 }
@@ -1385,6 +1410,7 @@ Status ImporterBase::ConvertFeedsToPlaceholders(
 
 Status ImporterBase::PrepareConvert(const Graph& graph) {
   TF_RETURN_IF_ERROR(RemoveBackedges(graph));
+  TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
 
   auto node_name_map = graph_->BuildNodeNameIndex();
 
@@ -1579,7 +1605,8 @@ Status ImporterBase::ConvertFunctionArgAndRets(
   return Status::OK();
 }
 
-mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
+mlir::Location ImporterBase::GetLocation(const Node& node) {
+  DVLOG(1) << "Getting location for " << node.name() << " " << &node;
   // TODO(b/142400497): What is the semantic contract for locations?
   const auto& debug_info = debug_info_.traces();
 
@@ -1599,21 +1626,37 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
   std::string name_for_name_loc =
       function_name.empty() ? name.str() : (name + "@" + function_name).str();
   auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
-  const auto location_it = debug_info.find(debug_info_key);
-  if (location_it == debug_info.end()) {
-    return mlir::NameLoc::get(name_loc_id, context_);
-  }
 
-  // Convert the stack trace to a chain of mlir::CallSiteLocs.
-  const auto& trace = location_it->second;
   llvm::SmallVector<mlir::Location, 4> locations;
-  locations.reserve(trace.file_line_cols_size());
-  for (const auto& location : trace.file_line_cols()) {
-    const auto& file = debug_info_.files(location.file_index());
-    auto file_name = mlir::Identifier::get(file, context_);
-    auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
-                                                   location.col(), context_);
-    locations.push_back(file_line_loc);
+  // Prefer stack traces if available, fall back to debug info if not, and
+  // finally to just the name.
+  if (auto stack_trace = node.GetStackTrace()) {
+    DVLOG(1) << "Stack available for " << node.name();
+    absl::Span<const StackFrame> frames = stack_trace->ToFrames();
+    locations.reserve(frames.size());
+    for (const StackFrame& frame : llvm::reverse(frames)) {
+      auto file_name = mlir::Identifier::get(frame.file_name, context_);
+      // Use col 1 as there is no column info in StackTrace.
+      auto file_line_loc = mlir::FileLineColLoc::get(
+          file_name, frame.line_number, 1, context_);
+      locations.push_back(file_line_loc);
+    }
+  } else {
+    DVLOG(1) << "No stack trace for " << node.name();
+    const auto location_it = debug_info.find(debug_info_key);
+    if (location_it != debug_info.end()) {
+      DVLOG(1) << "Available serialized debug info for " << node.name();
+      // Convert the stack trace to a chain of mlir::CallSiteLocs.
+      const auto& trace = location_it->second;
+      locations.reserve(trace.file_line_cols_size());
+      for (const auto& location : trace.file_line_cols()) {
+        const auto& file = debug_info_.files(location.file_index());
+        auto file_name = mlir::Identifier::get(file, context_);
+        auto file_line_loc = mlir::FileLineColLoc::get(
+            file_name, location.line(), location.col(), context_);
+        locations.push_back(file_line_loc);
+      }
+    }
   }
 
   // If there are no locations in the stack trace, fall back to just a
@@ -1636,16 +1679,20 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
   // Hence, we use node name as location to keep it unique.
   // TODO(prakalps): In future the plan is to use tokens to pair source/sink
   // nodes. Then NextIteration nodes would not need to be handled separately.
-  if (node_def.op() == "NextIteration")
-    return create_location(node_def.name(), function_name_for_debug_info_);
+  if (node.type_string() == "NextIteration")
+    return create_location(node.name(), function_name_for_debug_info_);
+  if (node.GetStackTrace())
+    return create_location(node.name(), function_name_for_debug_info_);
+
+  const auto& node_def = node.def();
 
   auto original_nodes =
       node_def.experimental_debug_info().original_node_names();
   auto original_funcs =
       node_def.experimental_debug_info().original_func_names();
 
   if (original_nodes.empty()) {
-    return create_location(node_def.name(), function_name_for_debug_info_);
+    return create_location(node.name(), function_name_for_debug_info_);
   } else {
     // If the original nodes are defined, then we use them to get a list of
     // call sites, and then fuse them to a single fused location, with the name
@@ -1661,14 +1708,14 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
     }
     // store the name of the node_def
     node_locations.push_back(
-        create_location(node_def.name(), function_name_for_debug_info_));
+        create_location(node.name(), function_name_for_debug_info_));
     return mlir::FusedLoc::get(node_locations, context_);
   }
 }
 
 Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
                                               const Status& error_status) {
-  const mlir::Location location = GetLocation(node.def());
+  const mlir::Location location = GetLocation(node);
   mlir::emitError(location);
   return error_handler_.Combine(error_status);
 }
@@ -1832,8 +1879,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
     op_name = op_name + ".sink";
   }
 
-  const auto& node_def = node.def();
-  mlir::OperationState result(GetLocation(node_def), op_name);
+  mlir::OperationState result(GetLocation(node), op_name);
   for (int i = 0; i < node.num_outputs(); ++i) {
     // The backedge has been removed, so we shouldn't count the corresponding
     // output from the src node when converting to an operation.
@@ -1937,6 +1983,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
                                                  &result.attributes));
   }
 
+  const auto& node_def = node.def();
   result.attributes.push_back(builder_.getNamedAttr(
       "device", builder_.getStringAttr(std::string(node_def.device()))));
 
diff --git a/tensorflow/python/compiler/mlir/mlir_test.py b/tensorflow/python/compiler/mlir/mlir_test.py
index cb8ec5def35..adce6b12542 100644
--- a/tensorflow/python/compiler/mlir/mlir_test.py
+++ b/tensorflow/python/compiler/mlir/mlir_test.py
@@ -55,7 +55,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
         tensor_spec.TensorSpec(None, dtypes.float32))
     mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
     self.assertRegex(mlir_module, r'func @.*sqr.*\(')
-    self.assertRegex(mlir_module, r'loc\(')
+    self.assertRegex(mlir_module, r'callsite\(".*mlir_test.py":')
 
   @test_util.run_v2_only
   def testImportWithCall(self):
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index dfed13a2669..fe557bd434e 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -871,23 +871,29 @@ class DefFunctionTest(xla_test.XLATestCase):
 
       f(constant_op.constant(1))
 
-  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
-                                 'support stack traces')
   def testTensorArrayErrorMessage(self):
     with ops.device('device:{}:0'.format(self.device)):
 
       @def_function.function(jit_compile=True)
       def f():
-        ta = tensor_array_ops.TensorArray(
+        # The error messages of the old and new bridge differ in which op
+        # they flag: one points to the creation of the uninitialized tensor
+        # array, the other to the use of the uninitialized tensor array.
+        ta = tensor_array_ops.TensorArray(  # EXPECTED_MESSAGE_NEW
             dtype=dtypes.float32,
             size=2,
             dynamic_size=True,
             element_shape=(None,))
-        return ta.concat()  # EXPECTED_MESSAGE
+        return ta.concat()  # EXPECTED_MESSAGE_OLD
 
-      with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                  'EXPECTED_MESSAGE'):
-        f()
+      if test_util.is_mlir_bridge_enabled():
+        with self.assertRaisesRegex(errors.InternalError,
+                                    'EXPECTED_MESSAGE_NEW'):
+          f()
+      else:
+        with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                    'EXPECTED_MESSAGE_OLD'):
+          f()
 
 
 if __name__ == '__main__':
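
A minimal way to observe the effect of this change from Python, mirroring what the updated mlir_test.py asserts. This is a sketch only, not part of the patch: it reuses the internal tensorflow.python.compiler.mlir module that the test itself imports, and the exact textual output depends on the build.

    # Sketch: dump MLIR with debug info and check for callsite locations that
    # reference the Python frame where each op was created.
    import tensorflow as tf
    from tensorflow.python.compiler.mlir import mlir

    @tf.function
    def sqr(x):
      return x * x

    concrete_function = sqr.get_concrete_function(
        tf.TensorSpec(None, tf.float32))
    mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
    # With stack traces propagated, the dump should contain locations such as
    # loc(callsite("<file>.py":<line>:... ...)) rather than bare name locations.
    print(mlir_module)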