Print a cycle if detected by DFS.
Example output Directed cycle: fusion.48 get-tuple-element.32 fusion.62 get-tuple-element.67 fusion.44 get-tuple-element.65 fusion.48 PiperOrigin-RevId: 266934452
This commit is contained in:
parent
4fc96cee6e
commit
d70a2cf2ab
@ -2209,11 +2209,52 @@ string PrintName(const string& name, bool print_ids) {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
|
||||||
|
|
||||||
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
|
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
|
||||||
return StrCat(options.print_percent() ? "%" : "",
|
return StrCat(options.print_percent() ? "%" : "",
|
||||||
PrintName(name, options.print_ids()));
|
PrintName(name, options.print_ids()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
|
||||||
|
// This set contains HloInstructions from the top of `DFSStack` that might
|
||||||
|
// belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
|
||||||
|
// `subgraph` := {child,...,top}.
|
||||||
|
absl::flat_hash_set<const HloInstruction*> subgraph;
|
||||||
|
while (!dfs_stack->empty() && dfs_stack->back().second != child) {
|
||||||
|
subgraph.insert(dfs_stack->back().second);
|
||||||
|
dfs_stack->pop_back();
|
||||||
|
}
|
||||||
|
// Start dfs at `child` and find a cycle with all nodes in `subgraph`.
|
||||||
|
absl::flat_hash_set<const HloInstruction*> visited;
|
||||||
|
absl::InlinedVector<const HloInstruction*, 16> dfs;
|
||||||
|
dfs.push_back(child);
|
||||||
|
while (!dfs.empty()) {
|
||||||
|
bool found_next_instr = false;
|
||||||
|
for (const auto& user : dfs.back()->users()) {
|
||||||
|
if (user == child) {
|
||||||
|
dfs.push_back(child);
|
||||||
|
LOG(INFO) << "\n\nDirected cycle:\n "
|
||||||
|
<< absl::StrJoin(
|
||||||
|
dfs, "\n ",
|
||||||
|
[](std::string* out, const HloInstruction* instr) {
|
||||||
|
out->append(instr->name());
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!subgraph.contains(user) || visited.contains(user)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
visited.insert(user);
|
||||||
|
dfs.push_back(user);
|
||||||
|
found_next_instr = true;
|
||||||
|
}
|
||||||
|
if (!found_next_instr) {
|
||||||
|
dfs.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
string HloInstruction::ToString(const HloPrintOptions& options) const {
|
string HloInstruction::ToString(const HloPrintOptions& options) const {
|
||||||
@ -2847,8 +2888,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
|||||||
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
|
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
|
||||||
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
|
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
|
||||||
|
|
||||||
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
|
|
||||||
|
|
||||||
// Push "child" onto the dfs_stack if not already visited. Returns false if a
|
// Push "child" onto the dfs_stack if not already visited. Returns false if a
|
||||||
// cycle was detected, and true otherwise.
|
// cycle was detected, and true otherwise.
|
||||||
template <typename Visitor>
|
template <typename Visitor>
|
||||||
@ -2926,6 +2965,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
|
|||||||
const size_t old_dfs_stack_size = dfs_stack.size();
|
const size_t old_dfs_stack_size = dfs_stack.size();
|
||||||
for (HloInstruction* child : current_node->operands()) {
|
for (HloInstruction* child : current_node->operands()) {
|
||||||
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
||||||
|
PrintCycle(child, &dfs_stack);
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"A cycle is detected while visiting instruction %s",
|
"A cycle is detected while visiting instruction %s",
|
||||||
current_node->ToString());
|
current_node->ToString());
|
||||||
@ -2935,6 +2975,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
|
|||||||
if (!ignore_control_predecessors) {
|
if (!ignore_control_predecessors) {
|
||||||
for (HloInstruction* child : current_node->control_predecessors()) {
|
for (HloInstruction* child : current_node->control_predecessors()) {
|
||||||
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
|
||||||
|
PrintCycle(child, &dfs_stack);
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"A cycle is detected while visiting instruction %s",
|
"A cycle is detected while visiting instruction %s",
|
||||||
current_node->ToString());
|
current_node->ToString());
|
||||||
|
Loading…
Reference in New Issue
Block a user