Add hlo_graph_dumper::GetInstructionsInNeighborhood, which lets you

graph the nodes that are "near" a particular node.

PiperOrigin-RevId: 162692461
This commit is contained in:
Justin Lebar 2017-07-20 18:09:16 -07:00 committed by TensorFlower Gardener
parent 9cf6fe1ad3
commit 386f4aef0d
2 changed files with 157 additions and 13 deletions

View File

@ -64,6 +64,32 @@ enum ColorScheme {
kYellow, kYellow,
}; };
// Used to indicate how we should treat a given HLOInstruction in the graph --
// should we treat it like normal, hide it, or highlight it?
enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
// shown, hidden, or highlighted.
class NodeFilter {
public:
NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
explicit NodeFilter(
std::function<NodeFilterResult(const HloInstruction* instr)> filter)
: filter_(std::move(filter)) {}
bool Show(const HloInstruction* instr) const {
return filter_(instr) != kHideNode;
}
bool Highlight(const HloInstruction* instr) const {
return filter_(instr) == kHighlightNode;
}
private:
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
// Given a ColorScheme, returns an attribute string for a node of that color. // Given a ColorScheme, returns an attribute string for a node of that color.
// Sets the node's fill, stroke, and text colors. // Sets the node's fill, stroke, and text colors.
// //
@ -134,7 +160,8 @@ string InstructionSequenceGraph(
const std::list<std::unique_ptr<HloInstruction>>& instructions, const std::list<std::unique_ptr<HloInstruction>>& instructions,
bool show_addresses, bool show_layouts, bool show_addresses, bool show_layouts,
std::vector<string>* intercomputation_edges, std::vector<string>* intercomputation_edges,
const HloExecutionProfile* hlo_execution_profile) { const HloExecutionProfile* hlo_execution_profile,
const NodeFilter& filter) {
string graph_body; string graph_body;
// Create a single "record" node for the parameters. This node is a // Create a single "record" node for the parameters. This node is a
@ -157,6 +184,9 @@ string InstructionSequenceGraph(
param_node_name = param_node_name =
StrCat("parameters_", InstructionId(param_instructions[0])); StrCat("parameters_", InstructionId(param_instructions[0]));
for (auto& param : param_instructions) { for (auto& param : param_instructions) {
if (!filter.Show(param)) {
continue;
}
string label = StrCat(param->parameter_name(), "\\n", string label = StrCat(param->parameter_name(), "\\n",
ShapeUtil::HumanString(param->shape())); ShapeUtil::HumanString(param->shape()));
if (show_addresses) { if (show_addresses) {
@ -178,6 +208,9 @@ string InstructionSequenceGraph(
} }
for (auto& instruction : instructions) { for (auto& instruction : instructions) {
if (!filter.Show(instruction.get())) {
continue;
}
ColorScheme color = kYellow; ColorScheme color = kYellow;
string shape = "box"; string shape = "box";
string name = string name =
@ -325,7 +358,7 @@ string InstructionSequenceGraph(
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:
// Make ReducePrecision ops a bit more visible, since typically they // Make ReducePrecision ops a bit more visible, since typically they
// will be inserted as modifications to an existing graph. // will be inserted as modifications to an existing graph.
color = kDarkRed; color = kRed;
break; break;
} }
@ -371,6 +404,12 @@ string InstructionSequenceGraph(
} }
} }
// If this node is highlighted, override its formatting.
if (filter.Highlight(instruction.get())) {
shape = "diamond";
color = kDarkRed;
}
Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
InstructionId(instruction.get()).c_str(), label.c_str(), InstructionId(instruction.get()).c_str(), label.c_str(),
shape.c_str(), NodeColorAttributes(color).c_str()); shape.c_str(), NodeColorAttributes(color).c_str());
@ -378,6 +417,10 @@ string InstructionSequenceGraph(
// Create edges from the instruction's operands to the instruction. // Create edges from the instruction's operands to the instruction.
int64 operand_number = 0; int64 operand_number = 0;
for (auto* operand : instruction->operands()) { for (auto* operand : instruction->operands()) {
if (!filter.Show(operand)) {
++operand_number;
continue;
}
string src; string src;
if (operand->opcode() == HloOpcode::kParameter) { if (operand->opcode() == HloOpcode::kParameter) {
// If operand is a parameter, then select the proper partition (port) in // If operand is a parameter, then select the proper partition (port) in
@ -405,10 +448,11 @@ string InstructionSequenceGraph(
StrAppend(&graph_body, StrAppend(&graph_body,
"label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n" "label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
"color=lightgrey;\n"); "color=lightgrey;\n");
StrAppend(&graph_body, InstructionSequenceGraph( StrAppend(&graph_body,
instruction->fused_instructions(), InstructionSequenceGraph(instruction->fused_instructions(),
show_addresses, show_layouts, show_addresses, show_layouts,
intercomputation_edges, hlo_execution_profile), intercomputation_edges,
hlo_execution_profile, NodeFilter()),
"}\n"); "}\n");
string fusion_edge = string fusion_edge =
StrCat(InstructionId(instruction->fused_expression_root()), " -> ", StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
@ -450,7 +494,8 @@ svg text {
string ComputationToDotGraph(const HloComputation& computation, string ComputationToDotGraph(const HloComputation& computation,
const string& label, bool show_addresses, const string& label, bool show_addresses,
bool show_layouts, bool show_layouts,
const HloExecutionProfile* hlo_execution_profile) { const HloExecutionProfile* hlo_execution_profile,
const NodeFilter& filter) {
string graph_label = StrCat(label, "<br/>", computation.name()); string graph_label = StrCat(label, "<br/>", computation.name());
if (hlo_execution_profile != nullptr) { if (hlo_execution_profile != nullptr) {
auto cycles = hlo_execution_profile->total_cycles_executed(computation); auto cycles = hlo_execution_profile->total_cycles_executed(computation);
@ -467,12 +512,31 @@ stylesheet="%s"
)", )",
graph_label.c_str(), dot_stylesheet); graph_label.c_str(), dot_stylesheet);
std::unordered_set<const HloComputation*> computations_to_dump;
for (const auto& instr : computation.instructions()) {
if (!filter.Show(instr.get())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
computations_to_dump.insert(instr->fused_instructions_computation());
}
for (const HloComputation* computation : instr->called_computations()) {
computations_to_dump.insert(computation);
}
}
// Emit embedded computations as subgraph clusters. // Emit embedded computations as subgraph clusters.
std::vector<string> intercomputation_edges; std::vector<string> intercomputation_edges;
for (auto embedded : computation.MakeEmbeddedComputationsList()) { for (const HloComputation* embedded :
computation.MakeEmbeddedComputationsList()) {
if (!computations_to_dump.count(embedded)) {
continue;
}
// Don't pass our filter down into the subcomputation -- always render the
// whole thing.
string graph_body = InstructionSequenceGraph( string graph_body = InstructionSequenceGraph(
embedded->instructions(), show_addresses, show_layouts, embedded->instructions(), show_addresses, show_layouts,
&intercomputation_edges, hlo_execution_profile); &intercomputation_edges, hlo_execution_profile, NodeFilter());
Appendf(&graph, Appendf(&graph,
"subgraph cluster_%s " "subgraph cluster_%s "
"{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n", "{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
@ -482,7 +546,7 @@ stylesheet="%s"
StrAppend(&graph, StrAppend(&graph,
InstructionSequenceGraph(computation.instructions(), show_addresses, InstructionSequenceGraph(computation.instructions(), show_addresses,
show_layouts, &intercomputation_edges, show_layouts, &intercomputation_edges,
hlo_execution_profile)); hlo_execution_profile, filter));
// Edges between computations (subgraph clusters) must be emitted last for the // Edges between computations (subgraph clusters) must be emitted last for the
// graph to be rendered properly for some reason. // graph to be rendered properly for some reason.
@ -555,6 +619,51 @@ class FileGraphRenderer : public GraphRendererInterface {
} }
}; };
// Gets roughly all instructions whose distance from root is <= radius.
std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
const HloInstruction& root, int64 radius) {
std::unordered_set<const HloInstruction*> ret;
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
worklist.push_back({&root, 0});
while (!worklist.empty()) {
const HloInstruction* instr;
int64 depth;
std::tie(instr, depth) = worklist.front();
worklist.pop_front();
ret.insert(instr);
if (depth == radius) {
continue;
}
// Don't traverse into tuples' operands unless the tuple is the root.
// Usually a tuple is the bottommost node in the graph, and so its operands
// are not interesting to the graph at hand.
if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
if (ret.find(operand) == ret.end()) {
worklist.push_back({operand, depth + 1});
}
}
}
// If you're looking at node X, it's probably not interesting that node Y
// also happens to use the same constant, so we don't traverse into
// constants' users.
if (instr->opcode() != HloOpcode::kConstant) {
for (const HloInstruction* user : instr->users()) {
if (ret.find(user) == ret.end()) {
worklist.push_back({user, depth + 1});
}
}
}
}
return ret;
}
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
} // namespace } // namespace
@ -575,9 +684,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph( graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else { } else {
graph = ComputationToDotGraph( graph = ComputationToDotGraph(computation, label,
computation, label, debug_options.xla_hlo_graph_addresses(), debug_options.xla_hlo_graph_addresses(),
debug_options.xla_hlo_graph_layout(), hlo_execution_profile); debug_options.xla_hlo_graph_layout(),
hlo_execution_profile, NodeFilter());
graph_url = GetGraphRenderer()->RenderGraph( graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options); graph, GraphRendererInterface::DOT_GRAPH, debug_options);
} }
@ -586,6 +696,33 @@ string DumpGraph(const HloComputation& computation, const string& label,
return graph_url; return graph_url;
} }
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
auto debug_options = node.GetModule()->config().debug_options();
std::unordered_set<const HloInstruction*> neighborhood =
GetInstructionsInNeighborhood(node, radius);
NodeFilter filter([&](const HloInstruction* instr) {
if (instr == &node) {
return kHighlightNode;
}
if (neighborhood.find(instr) != neighborhood.end()) {
return kNormalNode;
}
return kHideNode;
});
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
string graph = ComputationToDotGraph(
*node.parent(), label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*show_layouts=*/debug_options.xla_hlo_graph_layout(),
/*hlo_execution_profile=*/nullptr, filter);
return GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
void DumpText(const HloModule& module, const string& label, void DumpText(const HloModule& module, const string& label,
const string& directory_path, bool do_prefix) { const string& directory_path, bool do_prefix) {
Env* env = Env::Default(); Env* env = Env::Default();

View File

@ -57,6 +57,13 @@ string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options, const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr); const HloExecutionProfile* hlo_execution_profile = nullptr);
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
// The number of nodes dumped is controlled by the radius parameter, which
// (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius);
// Dumps the HloModule::ToString() as a file into the provided directory path // Dumps the HloModule::ToString() as a file into the provided directory path
// suffixed with the provided label. // suffixed with the provided label.
// //