From 386f4aef0d05489cc3a4cdc01470533849569dba Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Thu, 20 Jul 2017 18:09:16 -0700
Subject: [PATCH] Add hlo_graph_dumper::GetInstructionsInNeighborhood, which
 lets you graph the nodes that are "near" a particular node.

PiperOrigin-RevId: 162692461
---
 .../compiler/xla/service/hlo_graph_dumper.cc  | 163 ++++++++++++++++--
 .../compiler/xla/service/hlo_graph_dumper.h   |   7 +
 2 files changed, 157 insertions(+), 13 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 6138496fe5c..a712dc3b33b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -64,6 +64,32 @@ enum ColorScheme {
   kYellow,
 };
 
+// Tells the graph-drawing routines how to treat a given HloInstruction:
+// draw it normally, leave it out of the graph, or draw it highlighted.
+enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
+
+// NodeFilter wraps a callback from HloInstruction* to NodeFilterResult, so
+// callers can choose which nodes are shown, hidden, or highlighted.
+class NodeFilter {
+ public:
+  using FilterFn = std::function<NodeFilterResult(const HloInstruction*)>;
+
+  // The default filter classifies every instruction as kNormalNode.
+  NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
+  explicit NodeFilter(FilterFn filter) : filter_(std::move(filter)) {}
+
+  // A node is shown unless hidden, and highlighted only when asked for.
+  bool Show(const HloInstruction* instr) const {
+    return filter_(instr) != kHideNode;
+  }
+  bool Highlight(const HloInstruction* instr) const {
+    return filter_(instr) == kHighlightNode;
+  }
+
+ private:
+  FilterFn filter_;
+};
+
 // Given a ColorScheme, returns an attribute string for a node of that color.
 // Sets the node's fill, stroke, and text colors.
 //
@@ -134,7 +160,8 @@ string InstructionSequenceGraph(
     const std::list<std::unique_ptr<HloInstruction>>& instructions,
     bool show_addresses, bool show_layouts,
     std::vector<string>* intercomputation_edges,
-    const HloExecutionProfile* hlo_execution_profile) {
+    const HloExecutionProfile* hlo_execution_profile,
+    const NodeFilter& filter) {
   string graph_body;
 
   // Create a single "record" node for the parameters. This node is a
@@ -157,6 +184,9 @@ string InstructionSequenceGraph(
     param_node_name =
         StrCat("parameters_", InstructionId(param_instructions[0]));
     for (auto& param : param_instructions) {
+      if (!filter.Show(param)) {
+        continue;
+      }
       string label = StrCat(param->parameter_name(), "\\n",
                             ShapeUtil::HumanString(param->shape()));
       if (show_addresses) {
@@ -178,6 +208,9 @@ string InstructionSequenceGraph(
   }
 
   for (auto& instruction : instructions) {
+    if (!filter.Show(instruction.get())) {
+      continue;
+    }
     ColorScheme color = kYellow;
     string shape = "box";
     string name =
@@ -325,7 +358,7 @@ string InstructionSequenceGraph(
       case HloOpcode::kReducePrecision:
         // Make ReducePrecision ops a bit more visible, since typically they
         // will be inserted as modifications to an existing graph.
-        color = kDarkRed;
+        color = kRed;
         break;
     }
 
@@ -371,6 +404,12 @@ string InstructionSequenceGraph(
       }
     }
 
+    // If this node is highlighted, override its formatting.
+    if (filter.Highlight(instruction.get())) {
+      shape = "diamond";
+      color = kDarkRed;
+    }
+
     Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
             InstructionId(instruction.get()).c_str(), label.c_str(),
             shape.c_str(), NodeColorAttributes(color).c_str());
@@ -378,6 +417,10 @@ string InstructionSequenceGraph(
     // Create edges from the instruction's operands to the instruction.
     int64 operand_number = 0;
     for (auto* operand : instruction->operands()) {
+      if (!filter.Show(operand)) {
+        ++operand_number;
+        continue;
+      }
       string src;
       if (operand->opcode() == HloOpcode::kParameter) {
         // If operand is a parameter, then select the proper partition (port) in
@@ -405,10 +448,11 @@ string InstructionSequenceGraph(
       StrAppend(&graph_body,
                 "label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
                 "color=lightgrey;\n");
-      StrAppend(&graph_body, InstructionSequenceGraph(
-                                 instruction->fused_instructions(),
-                                 show_addresses, show_layouts,
-                                 intercomputation_edges, hlo_execution_profile),
+      StrAppend(&graph_body,
+                InstructionSequenceGraph(instruction->fused_instructions(),
+                                         show_addresses, show_layouts,
+                                         intercomputation_edges,
+                                         hlo_execution_profile, NodeFilter()),
                 "}\n");
       string fusion_edge =
           StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
@@ -450,7 +494,8 @@ svg text {
 string ComputationToDotGraph(const HloComputation& computation,
                              const string& label, bool show_addresses,
                              bool show_layouts,
-                             const HloExecutionProfile* hlo_execution_profile) {
+                             const HloExecutionProfile* hlo_execution_profile,
+                             const NodeFilter& filter) {
   string graph_label = StrCat(label, "<br/>", computation.name());
   if (hlo_execution_profile != nullptr) {
     auto cycles = hlo_execution_profile->total_cycles_executed(computation);
@@ -467,12 +512,31 @@ stylesheet="%s"
 )",
       graph_label.c_str(), dot_stylesheet);
 
+  std::unordered_set<const HloComputation*> computations_to_dump;
+  for (const auto& instr : computation.instructions()) {
+    if (!filter.Show(instr.get())) {
+      continue;
+    }
+    if (instr->opcode() == HloOpcode::kFusion) {
+      computations_to_dump.insert(instr->fused_instructions_computation());
+    }
+    for (const HloComputation* computation : instr->called_computations()) {
+      computations_to_dump.insert(computation);
+    }
+  }
+
   // Emit embedded computations as subgraph clusters.
   std::vector<string> intercomputation_edges;
-  for (auto embedded : computation.MakeEmbeddedComputationsList()) {
+  for (const HloComputation* embedded :
+       computation.MakeEmbeddedComputationsList()) {
+    if (!computations_to_dump.count(embedded)) {
+      continue;
+    }
+    // Don't pass our filter down into the subcomputation -- always render the
+    // whole thing.
     string graph_body = InstructionSequenceGraph(
         embedded->instructions(), show_addresses, show_layouts,
-        &intercomputation_edges, hlo_execution_profile);
+        &intercomputation_edges, hlo_execution_profile, NodeFilter());
     Appendf(&graph,
             "subgraph cluster_%s "
             "{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
@@ -482,7 +546,7 @@ stylesheet="%s"
   StrAppend(&graph,
             InstructionSequenceGraph(computation.instructions(), show_addresses,
                                      show_layouts, &intercomputation_edges,
-                                     hlo_execution_profile));
+                                     hlo_execution_profile, filter));
 
   // Edges between computations (subgraph clusters) must be emitted last for the
   // graph to be rendered properly for some reason.
@@ -555,6 +619,51 @@ class FileGraphRenderer : public GraphRendererInterface {
   }
 };
 
+// Gets roughly all instructions whose distance from root is <= radius, via a
+// breadth-first walk over operand and user edges.  A negative radius never
+// equals a depth, so then every instruction the walk can reach is returned.
+std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
+    const HloInstruction& root, int64 radius) {
+  // Instructions are added to `ret` when first enqueued, so each one is
+  // queued and expanded once (at its minimum depth, as the queue is FIFO).
+  std::unordered_set<const HloInstruction*> ret = {&root};
+  std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
+  worklist.push_back({&root, 0});
+
+  while (!worklist.empty()) {
+    const HloInstruction* instr;
+    int64 depth;
+    std::tie(instr, depth) = worklist.front();
+    worklist.pop_front();
+    if (depth == radius) {
+      continue;
+    }
+
+    // Don't traverse into tuples' operands unless the tuple is the root.
+    // Usually a tuple is the bottommost node in the graph, and so its operands
+    // are not interesting to the graph at hand.
+    if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
+      for (const HloInstruction* operand : instr->operands()) {
+        if (ret.insert(operand).second) {
+          worklist.push_back({operand, depth + 1});
+        }
+      }
+    }
+
+    // If you're looking at node X, it's probably not interesting that node Y
+    // also happens to use the same constant, so we don't traverse into
+    // constants' users.
+    if (instr->opcode() != HloOpcode::kConstant) {
+      for (const HloInstruction* user : instr->users()) {
+        if (ret.insert(user).second) {
+          worklist.push_back({user, depth + 1});
+        }
+      }
+    }
+  }
+  return ret;
+}
+
 XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
 
 }  // namespace
@@ -575,9 +684,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
     graph_url = FileGraphRenderer().RenderGraph(
         graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
   } else {
-    graph = ComputationToDotGraph(
-        computation, label, debug_options.xla_hlo_graph_addresses(),
-        debug_options.xla_hlo_graph_layout(), hlo_execution_profile);
+    graph = ComputationToDotGraph(computation, label,
+                                  debug_options.xla_hlo_graph_addresses(),
+                                  debug_options.xla_hlo_graph_layout(),
+                                  hlo_execution_profile, NodeFilter());
     graph_url = GetGraphRenderer()->RenderGraph(
         graph, GraphRendererInterface::DOT_GRAPH, debug_options);
   }
@@ -586,6 +696,33 @@ string DumpGraph(const HloComputation& computation, const string& label,
   return graph_url;
 }
 
+// Renders only the instructions near `node`, drawing `node` itself
+// highlighted, and returns whatever the graph renderer returns.
+string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
+  auto debug_options = node.GetModule()->config().debug_options();
+
+  const std::unordered_set<const HloInstruction*> nearby =
+      GetInstructionsInNeighborhood(node, radius);
+
+  // Highlight the center node, keep its neighbors, and hide everything else.
+  NodeFilter filter([&](const HloInstruction* candidate) {
+    if (candidate == &node) {
+      return kHighlightNode;
+    }
+    return nearby.count(candidate) ? kNormalNode : kHideNode;
+  });
+
+  // No execution profile is passed when dumping a neighborhood.
+  const bool show_addresses = debug_options.xla_hlo_graph_addresses();
+  const bool show_layouts = debug_options.xla_hlo_graph_layout();
+  string graph = ComputationToDotGraph(
+      *node.parent(),
+      StrCat("Neighborhood of ", radius, " nodes around ", node.name()),
+      show_addresses, show_layouts, /*hlo_execution_profile=*/nullptr, filter);
+  return GetGraphRenderer()->RenderGraph(
+      graph, GraphRendererInterface::DOT_GRAPH, debug_options);
+}
+
 void DumpText(const HloModule& module, const string& label,
               const string& directory_path, bool do_prefix) {
   Env* env = Env::Default();
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index bc404a7a37f..0100d50c050 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -57,6 +57,13 @@ string DumpGraph(const HloComputation& computation, const string& label,
                  const DebugOptions& debug_options,
                  const HloExecutionProfile* hlo_execution_profile = nullptr);
 
+// Like DumpGraph, but renders only nodes "near" the given node in the graph.
+// The given node itself is drawn highlighted.
+//
+// `radius` is (roughly) the maximum distance, counted in operand/user edges,
+// that a node may be from the given node before it's omitted from the graph.
+string DumpNeighborhoodAround(const HloInstruction& node, int radius);
+
 // Dumps the HloModule::ToString() as a file into the provided directory path
 // suffixed with the provided label.
 //