Print a cycle if detected by DFS.
Example output Directed cycle: fusion.48 get-tuple-element.32 fusion.62 get-tuple-element.67 fusion.44 get-tuple-element.65 fusion.48 PiperOrigin-RevId: 266934452
This commit is contained in:
parent
4fc96cee6e
commit
d70a2cf2ab
@ -2209,11 +2209,52 @@ string PrintName(const string& name, bool print_ids) {
|
||||
|
||||
namespace {
|
||||
|
||||
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
|
||||
|
||||
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
|
||||
return StrCat(options.print_percent() ? "%" : "",
|
||||
PrintName(name, options.print_ids()));
|
||||
}
|
||||
|
||||
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
|
||||
// This set contains HloInstructions from the top of `DFSStack` that might
|
||||
// belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
|
||||
// `subgraph` := {child,...,top}.
|
||||
absl::flat_hash_set<const HloInstruction*> subgraph;
|
||||
while (!dfs_stack->empty() && dfs_stack->back().second != child) {
|
||||
subgraph.insert(dfs_stack->back().second);
|
||||
dfs_stack->pop_back();
|
||||
}
|
||||
// Start dfs at `child` and find a cycle with all nodes in `subgraph`.
|
||||
absl::flat_hash_set<const HloInstruction*> visited;
|
||||
absl::InlinedVector<const HloInstruction*, 16> dfs;
|
||||
dfs.push_back(child);
|
||||
while (!dfs.empty()) {
|
||||
bool found_next_instr = false;
|
||||
for (const auto& user : dfs.back()->users()) {
|
||||
if (user == child) {
|
||||
dfs.push_back(child);
|
||||
LOG(INFO) << "\n\nDirected cycle:\n "
|
||||
<< absl::StrJoin(
|
||||
dfs, "\n ",
|
||||
[](std::string* out, const HloInstruction* instr) {
|
||||
out->append(instr->name());
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (!subgraph.contains(user) || visited.contains(user)) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(user);
|
||||
dfs.push_back(user);
|
||||
found_next_instr = true;
|
||||
}
|
||||
if (!found_next_instr) {
|
||||
dfs.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
string HloInstruction::ToString(const HloPrintOptions& options) const {
|
||||
@ -2847,8 +2888,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
|
||||
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
|
||||
|
||||
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
|
||||
|
||||
// Push "child" onto the dfs_stack if not already visited. Returns false if a
|
||||
// cycle was detected, and true otherwise.
|
||||
template <typename Visitor>
|
||||
@ -2926,6 +2965,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
|
||||
const size_t old_dfs_stack_size = dfs_stack.size();
|
||||
for (HloInstruction* child : current_node->operands()) {
|
||||
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
||||
PrintCycle(child, &dfs_stack);
|
||||
return FailedPrecondition(
|
||||
"A cycle is detected while visiting instruction %s",
|
||||
current_node->ToString());
|
||||
@ -2935,6 +2975,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
|
||||
if (!ignore_control_predecessors) {
|
||||
for (HloInstruction* child : current_node->control_predecessors()) {
|
||||
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
||||
PrintCycle(child, &dfs_stack);
|
||||
return FailedPrecondition(
|
||||
"A cycle is detected while visiting instruction %s",
|
||||
current_node->ToString());
|
||||
|
Loading…
Reference in New Issue
Block a user