Print a cycle if detected by DFS.

Example output

Directed cycle:
  fusion.48
  get-tuple-element.32
  fusion.62
  get-tuple-element.67
  fusion.44
  get-tuple-element.65
  fusion.48

PiperOrigin-RevId: 266934452
This commit is contained in:
Alexander Belyaev 2019-09-03 08:30:58 -07:00 committed by TensorFlower Gardener
parent 4fc96cee6e
commit d70a2cf2ab

View File

@ -2209,11 +2209,52 @@ string PrintName(const string& name, bool print_ids) {
namespace {
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
return StrCat(options.print_percent() ? "%" : "",
PrintName(name, options.print_ids()));
}
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
// This set contains HloInstructions from the top of `DFSStack` that might
// belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
// `subgraph` := {child,...,top}.
absl::flat_hash_set<const HloInstruction*> subgraph;
while (!dfs_stack->empty() && dfs_stack->back().second != child) {
subgraph.insert(dfs_stack->back().second);
dfs_stack->pop_back();
}
// Start dfs at `child` and find a cycle with all nodes in `subgraph`.
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 16> dfs;
dfs.push_back(child);
while (!dfs.empty()) {
bool found_next_instr = false;
for (const auto& user : dfs.back()->users()) {
if (user == child) {
dfs.push_back(child);
LOG(INFO) << "\n\nDirected cycle:\n "
<< absl::StrJoin(
dfs, "\n ",
[](std::string* out, const HloInstruction* instr) {
out->append(instr->name());
});
return;
}
if (!subgraph.contains(user) || visited.contains(user)) {
continue;
}
visited.insert(user);
dfs.push_back(user);
found_next_instr = true;
}
if (!found_next_instr) {
dfs.pop_back();
}
}
}
} // namespace
string HloInstruction::ToString(const HloPrintOptions& options) const {
@ -2847,8 +2888,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
template <typename Visitor>
@ -2926,6 +2965,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const size_t old_dfs_stack_size = dfs_stack.size();
for (HloInstruction* child : current_node->operands()) {
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
PrintCycle(child, &dfs_stack);
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());
@ -2935,6 +2975,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
if (!ignore_control_predecessors) {
for (HloInstruction* child : current_node->control_predecessors()) {
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
PrintCycle(child, &dfs_stack);
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());