Use StackFrame of node during import
Copy the StackFrames over post copying the Graph (we need to remove that still) PiperOrigin-RevId: 352904629 Change-Id: I36dadc7171e26b00726dfc0863d719a4746f40c6
This commit is contained in:
parent
bc2269a05e
commit
03acd8ec0b
@ -377,7 +377,7 @@ class ImporterBase {
|
|||||||
// there are multiple "original_node_names", a FusedLoc is returned. If the
|
// there are multiple "original_node_names", a FusedLoc is returned. If the
|
||||||
// node name couldn't be found in the input DebugInfo, a NameLoc is used as
|
// node name couldn't be found in the input DebugInfo, a NameLoc is used as
|
||||||
// the location.
|
// the location.
|
||||||
mlir::Location GetLocation(const NodeDef& node);
|
mlir::Location GetLocation(const Node& node);
|
||||||
|
|
||||||
// Appends the location string for the node to the error message and returns
|
// Appends the location string for the node to the error message and returns
|
||||||
// the combined error status.
|
// the combined error status.
|
||||||
@ -598,6 +598,31 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
|
|||||||
GetReversePostOrder(
|
GetReversePostOrder(
|
||||||
*graph_, &ordered_nodes_,
|
*graph_, &ordered_nodes_,
|
||||||
[](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
|
[](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CopyStackTraces(const Graph& from, Graph* to) {
|
||||||
|
// Copy over the stack traces.
|
||||||
|
// TODO(jpienaar): This really shouldn't be needed, copying the Graph above
|
||||||
|
// and then needing these traversals is unfortunate.
|
||||||
|
std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex();
|
||||||
|
for (Node* node : to->nodes()) {
|
||||||
|
if (const Node* old_node = node_map[node->name()]) {
|
||||||
|
if (const std::shared_ptr<AbstractStackTrace>& stack =
|
||||||
|
old_node->GetStackTrace()) {
|
||||||
|
DVLOG(2) << "Stack for " << node->name() << " "
|
||||||
|
<< old_node->GetStackTrace()->ToString(
|
||||||
|
AbstractStackTrace::TracePrintingOptions());
|
||||||
|
node->SetStackTrace(stack);
|
||||||
|
} else {
|
||||||
|
DVLOG(1) << "No stack for " << node->name() << " (" << node
|
||||||
|
<< ") in Graph " << &from;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
DVLOG(1) << "No stack for " << node->name() << " (" << node
|
||||||
|
<< ") in Graph " << &from;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -1385,6 +1410,7 @@ Status ImporterBase::ConvertFeedsToPlaceholders(
|
|||||||
|
|
||||||
Status ImporterBase::PrepareConvert(const Graph& graph) {
|
Status ImporterBase::PrepareConvert(const Graph& graph) {
|
||||||
TF_RETURN_IF_ERROR(RemoveBackedges(graph));
|
TF_RETURN_IF_ERROR(RemoveBackedges(graph));
|
||||||
|
TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get()));
|
||||||
|
|
||||||
auto node_name_map = graph_->BuildNodeNameIndex();
|
auto node_name_map = graph_->BuildNodeNameIndex();
|
||||||
|
|
||||||
@ -1579,7 +1605,8 @@ Status ImporterBase::ConvertFunctionArgAndRets(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
mlir::Location ImporterBase::GetLocation(const Node& node) {
|
||||||
|
DVLOG(1) << "Getting location for " << node.name() << " " << &node;
|
||||||
// TODO(b/142400497): What is the semantic contract for locations?
|
// TODO(b/142400497): What is the semantic contract for locations?
|
||||||
const auto& debug_info = debug_info_.traces();
|
const auto& debug_info = debug_info_.traces();
|
||||||
|
|
||||||
@ -1599,22 +1626,38 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
|||||||
std::string name_for_name_loc =
|
std::string name_for_name_loc =
|
||||||
function_name.empty() ? name.str() : (name + "@" + function_name).str();
|
function_name.empty() ? name.str() : (name + "@" + function_name).str();
|
||||||
auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
|
auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
|
||||||
const auto location_it = debug_info.find(debug_info_key);
|
|
||||||
if (location_it == debug_info.end()) {
|
|
||||||
return mlir::NameLoc::get(name_loc_id, context_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Location, 4> locations;
|
||||||
|
// Prefer stack traces if available, fallback to debug info if not, and then
|
||||||
|
// finally to just name.
|
||||||
|
if (auto stack_trace = node.GetStackTrace()) {
|
||||||
|
DVLOG(1) << "Stack available for " << node.name();
|
||||||
|
absl::Span<const StackFrame> frames = stack_trace->ToFrames();
|
||||||
|
locations.reserve(frames.size());
|
||||||
|
for (const StackFrame& frame : llvm::reverse(frames)) {
|
||||||
|
auto file_name = mlir::Identifier::get(frame.file_name, context_);
|
||||||
|
// Use col 1 as there is no column info in StackTrace.
|
||||||
|
auto file_line_loc = mlir::FileLineColLoc::get(
|
||||||
|
file_name, frame.line_number, 1, context_);
|
||||||
|
locations.push_back(file_line_loc);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
DVLOG(1) << "No stack trace for " << node.name();
|
||||||
|
const auto location_it = debug_info.find(debug_info_key);
|
||||||
|
if (location_it != debug_info.end()) {
|
||||||
|
DVLOG(1) << "Available serialized debug info for " << node.name();
|
||||||
// Convert the stack trace to a chain of mlir::CallSiteLocs.
|
// Convert the stack trace to a chain of mlir::CallSiteLocs.
|
||||||
const auto& trace = location_it->second;
|
const auto& trace = location_it->second;
|
||||||
llvm::SmallVector<mlir::Location, 4> locations;
|
|
||||||
locations.reserve(trace.file_line_cols_size());
|
locations.reserve(trace.file_line_cols_size());
|
||||||
for (const auto& location : trace.file_line_cols()) {
|
for (const auto& location : trace.file_line_cols()) {
|
||||||
const auto& file = debug_info_.files(location.file_index());
|
const auto& file = debug_info_.files(location.file_index());
|
||||||
auto file_name = mlir::Identifier::get(file, context_);
|
auto file_name = mlir::Identifier::get(file, context_);
|
||||||
auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
|
auto file_line_loc = mlir::FileLineColLoc::get(
|
||||||
location.col(), context_);
|
file_name, location.line(), location.col(), context_);
|
||||||
locations.push_back(file_line_loc);
|
locations.push_back(file_line_loc);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If there are no locations in the stack trace, fall back to just a
|
// If there are no locations in the stack trace, fall back to just a
|
||||||
// NameLoc with no child.
|
// NameLoc with no child.
|
||||||
@ -1636,16 +1679,20 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
|||||||
// Hence, we use node name as location to keep it unique.
|
// Hence, we use node name as location to keep it unique.
|
||||||
// TODO(prakalps): In future the plan is to use tokens to pair source/sink
|
// TODO(prakalps): In future the plan is to use tokens to pair source/sink
|
||||||
// nodes. Then NextIteration nodes would not need to be handled separately.
|
// nodes. Then NextIteration nodes would not need to be handled separately.
|
||||||
if (node_def.op() == "NextIteration")
|
if (node.type_string() == "NextIteration")
|
||||||
return create_location(node_def.name(), function_name_for_debug_info_);
|
return create_location(node.name(), function_name_for_debug_info_);
|
||||||
|
|
||||||
|
if (node.GetStackTrace())
|
||||||
|
return create_location(node.name(), function_name_for_debug_info_);
|
||||||
|
|
||||||
|
const auto& node_def = node.def();
|
||||||
auto original_nodes =
|
auto original_nodes =
|
||||||
node_def.experimental_debug_info().original_node_names();
|
node_def.experimental_debug_info().original_node_names();
|
||||||
auto original_funcs =
|
auto original_funcs =
|
||||||
node_def.experimental_debug_info().original_func_names();
|
node_def.experimental_debug_info().original_func_names();
|
||||||
|
|
||||||
if (original_nodes.empty()) {
|
if (original_nodes.empty()) {
|
||||||
return create_location(node_def.name(), function_name_for_debug_info_);
|
return create_location(node.name(), function_name_for_debug_info_);
|
||||||
} else {
|
} else {
|
||||||
// If the original nodes are defined, then we use them to get a list of
|
// If the original nodes are defined, then we use them to get a list of
|
||||||
// call sites, and then fuse them to a single fused location, with the name
|
// call sites, and then fuse them to a single fused location, with the name
|
||||||
@ -1661,14 +1708,14 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
|||||||
}
|
}
|
||||||
// store the name of the node_def
|
// store the name of the node_def
|
||||||
node_locations.push_back(
|
node_locations.push_back(
|
||||||
create_location(node_def.name(), function_name_for_debug_info_));
|
create_location(node.name(), function_name_for_debug_info_));
|
||||||
return mlir::FusedLoc::get(node_locations, context_);
|
return mlir::FusedLoc::get(node_locations, context_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
|
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
|
||||||
const Status& error_status) {
|
const Status& error_status) {
|
||||||
const mlir::Location location = GetLocation(node.def());
|
const mlir::Location location = GetLocation(node);
|
||||||
mlir::emitError(location);
|
mlir::emitError(location);
|
||||||
return error_handler_.Combine(error_status);
|
return error_handler_.Combine(error_status);
|
||||||
}
|
}
|
||||||
@ -1832,8 +1879,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
|||||||
op_name = op_name + ".sink";
|
op_name = op_name + ".sink";
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& node_def = node.def();
|
mlir::OperationState result(GetLocation(node), op_name);
|
||||||
mlir::OperationState result(GetLocation(node_def), op_name);
|
|
||||||
for (int i = 0; i < node.num_outputs(); ++i) {
|
for (int i = 0; i < node.num_outputs(); ++i) {
|
||||||
// The backedge has been removed, so we shouldn't count the corresponding
|
// The backedge has been removed, so we shouldn't count the corresponding
|
||||||
// output from the src node when converting to an operation.
|
// output from the src node when converting to an operation.
|
||||||
@ -1937,6 +1983,7 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
|||||||
&result.attributes));
|
&result.attributes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto& node_def = node.def();
|
||||||
result.attributes.push_back(builder_.getNamedAttr(
|
result.attributes.push_back(builder_.getNamedAttr(
|
||||||
"device", builder_.getStringAttr(std::string(node_def.device()))));
|
"device", builder_.getStringAttr(std::string(node_def.device()))));
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
|
|||||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||||
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
|
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
|
||||||
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
|
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
|
||||||
self.assertRegex(mlir_module, r'loc\(')
|
self.assertRegex(mlir_module, r'callsite\(".*mlir_test.py":')
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testImportWithCall(self):
|
def testImportWithCall(self):
|
||||||
|
@ -871,22 +871,28 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
f(constant_op.constant(1))
|
f(constant_op.constant(1))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
|
||||||
'support stack traces')
|
|
||||||
def testTensorArrayErrorMessage(self):
|
def testTensorArrayErrorMessage(self):
|
||||||
with ops.device('device:{}:0'.format(self.device)):
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
@def_function.function(jit_compile=True)
|
@def_function.function(jit_compile=True)
|
||||||
def f():
|
def f():
|
||||||
ta = tensor_array_ops.TensorArray(
|
# The error message as old and new bridge differ in which op they flag.
|
||||||
|
# The one points to the creation of the unitialized tensor array, the
|
||||||
|
# other is the use of the unitialized tensor array.
|
||||||
|
ta = tensor_array_ops.TensorArray( # EXPECTED_MESSAGE_NEW
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
size=2,
|
size=2,
|
||||||
dynamic_size=True,
|
dynamic_size=True,
|
||||||
element_shape=(None,))
|
element_shape=(None,))
|
||||||
return ta.concat() # EXPECTED_MESSAGE
|
return ta.concat() # EXPECTED_MESSAGE_OLD
|
||||||
|
|
||||||
|
if test_util.is_mlir_bridge_enabled():
|
||||||
|
with self.assertRaisesRegex(errors.InternalError,
|
||||||
|
'EXPECTED_MESSAGE_NEW'):
|
||||||
|
f()
|
||||||
|
else:
|
||||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
'EXPECTED_MESSAGE'):
|
'EXPECTED_MESSAGE_OLD'):
|
||||||
f()
|
f()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user