From 7790cd6c2fbef5753c5f036e636cef130dbb02c3 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Mon, 21 Dec 2020 11:54:36 -0800 Subject: [PATCH] [TF2XLA] Fix a segfault in propagating stack traces in XlaKernelCreator The node might not be alive when we are showing the stack trace PiperOrigin-RevId: 348503788 Change-Id: I80e9e6fa3f10aea7841a4d3bfad51b6a9854edb7 --- tensorflow/compiler/jit/compilability_check_util.cc | 11 ++++++----- tensorflow/compiler/jit/compilability_check_util.h | 4 ++-- tensorflow/compiler/jit/xla_kernel_creator.cc | 7 ++++--- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 7d853e584ee..fb4c187f5bd 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -152,10 +152,11 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( if (node_stack_trace != nullptr) { for (const auto& frame : *node_stack_trace) { stack_trace.emplace_back( - StackFrameView{frame.name, frame.function_name, frame.n}); + StackFrameView{frame.name, frame.function_name, frame.stack_trace}); } } - stack_trace.emplace_back(StackFrameView{node.name(), "", &node}); + stack_trace.emplace_back( + StackFrameView{node.name(), "", node.GetStackTrace()}); RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; IsCompilableNode(node, lib_runtime, &stack_trace, @@ -175,7 +176,7 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( if (node_stack_trace != nullptr) { for (const auto& frame : *node_stack_trace) { stack_trace.emplace_back( - StackFrameView{frame.name, frame.function_name, frame.n}); + StackFrameView{frame.name, frame.function_name, frame.stack_trace}); } } stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr}); @@ -361,7 +362,7 @@ bool RecursiveCompilabilityChecker::IsCompilableCall( bool is_compilable = true; for (const Node* node : fbody->graph->op_nodes()) { stack_trace->emplace_back( - StackFrameView{node->name(), function.name(), node}); + StackFrameView{node->name(), function.name(), node->GetStackTrace()}); is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace, &function, uncompilable_nodes); stack_trace->pop_back(); @@ -586,7 +587,7 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( return StackFrame{ std::string(stack_element.name), std::string(stack_element.function_name), - stack_element.n}; + stack_element.stack_trace}; }); node_info.name = std::string(stack_trace.back().name); diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index d341ef9456e..027253636f0 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -62,7 +62,7 @@ class RecursiveCompilabilityChecker { struct StackFrame { std::string name; std::string function_name; - const Node* n = nullptr; + std::shared_ptr stack_trace; }; // Contains information about uncompilable node inside a function body. @@ -197,7 +197,7 @@ class RecursiveCompilabilityChecker { struct StackFrameView { absl::string_view name; absl::string_view function_name; - const Node* n = nullptr; + std::shared_ptr stack_trace; }; bool IsCompilableNode( diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 8cc84f694f0..602c2d2cbfe 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -123,13 +123,14 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, std::string node_message = absl::StrCat( "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n", "The op is created at:\n"); - const Node* n = node_info.stack_trace.back().n; - if (n && n->GetStackTrace()) { + if (node_info.stack_trace.back().stack_trace) { AbstractStackTrace::TracePrintingOptions opts; opts.show_line_contents = true; opts.filter_common_prefix = true; opts.drop_internal_frames = true; - absl::StrAppend(&node_message, n->GetStackTrace()->ToString(opts)); + absl::StrAppend( + &node_message, + node_info.stack_trace.back().stack_trace->ToString(opts)); } else { absl::StrAppend(&node_message, "\n"); }