Add hlo_graph_dumper::GetInstructionsInNeighborhood, which lets you
graph the nodes that are "near" a particular node. PiperOrigin-RevId: 162692461
This commit is contained in:
parent
9cf6fe1ad3
commit
386f4aef0d
@ -64,6 +64,32 @@ enum ColorScheme {
|
|||||||
kYellow,
|
kYellow,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Used to indicate how we should treat a given HLOInstruction in the graph --
|
||||||
|
// should we treat it like normal, hide it, or highlight it?
|
||||||
|
enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
|
||||||
|
|
||||||
|
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
|
||||||
|
// It lets callers tell the graph-drawing routines which nodes they want to be
|
||||||
|
// shown, hidden, or highlighted.
|
||||||
|
class NodeFilter {
|
||||||
|
public:
|
||||||
|
NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
|
||||||
|
|
||||||
|
explicit NodeFilter(
|
||||||
|
std::function<NodeFilterResult(const HloInstruction* instr)> filter)
|
||||||
|
: filter_(std::move(filter)) {}
|
||||||
|
|
||||||
|
bool Show(const HloInstruction* instr) const {
|
||||||
|
return filter_(instr) != kHideNode;
|
||||||
|
}
|
||||||
|
bool Highlight(const HloInstruction* instr) const {
|
||||||
|
return filter_(instr) == kHighlightNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
|
||||||
|
};
|
||||||
|
|
||||||
// Given a ColorScheme, returns an attribute string for a node of that color.
|
// Given a ColorScheme, returns an attribute string for a node of that color.
|
||||||
// Sets the node's fill, stroke, and text colors.
|
// Sets the node's fill, stroke, and text colors.
|
||||||
//
|
//
|
||||||
@ -134,7 +160,8 @@ string InstructionSequenceGraph(
|
|||||||
const std::list<std::unique_ptr<HloInstruction>>& instructions,
|
const std::list<std::unique_ptr<HloInstruction>>& instructions,
|
||||||
bool show_addresses, bool show_layouts,
|
bool show_addresses, bool show_layouts,
|
||||||
std::vector<string>* intercomputation_edges,
|
std::vector<string>* intercomputation_edges,
|
||||||
const HloExecutionProfile* hlo_execution_profile) {
|
const HloExecutionProfile* hlo_execution_profile,
|
||||||
|
const NodeFilter& filter) {
|
||||||
string graph_body;
|
string graph_body;
|
||||||
|
|
||||||
// Create a single "record" node for the parameters. This node is a
|
// Create a single "record" node for the parameters. This node is a
|
||||||
@ -157,6 +184,9 @@ string InstructionSequenceGraph(
|
|||||||
param_node_name =
|
param_node_name =
|
||||||
StrCat("parameters_", InstructionId(param_instructions[0]));
|
StrCat("parameters_", InstructionId(param_instructions[0]));
|
||||||
for (auto& param : param_instructions) {
|
for (auto& param : param_instructions) {
|
||||||
|
if (!filter.Show(param)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
string label = StrCat(param->parameter_name(), "\\n",
|
string label = StrCat(param->parameter_name(), "\\n",
|
||||||
ShapeUtil::HumanString(param->shape()));
|
ShapeUtil::HumanString(param->shape()));
|
||||||
if (show_addresses) {
|
if (show_addresses) {
|
||||||
@ -178,6 +208,9 @@ string InstructionSequenceGraph(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto& instruction : instructions) {
|
for (auto& instruction : instructions) {
|
||||||
|
if (!filter.Show(instruction.get())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
ColorScheme color = kYellow;
|
ColorScheme color = kYellow;
|
||||||
string shape = "box";
|
string shape = "box";
|
||||||
string name =
|
string name =
|
||||||
@ -325,7 +358,7 @@ string InstructionSequenceGraph(
|
|||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
// Make ReducePrecision ops a bit more visible, since typically they
|
// Make ReducePrecision ops a bit more visible, since typically they
|
||||||
// will be inserted as modifications to an existing graph.
|
// will be inserted as modifications to an existing graph.
|
||||||
color = kDarkRed;
|
color = kRed;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -371,6 +404,12 @@ string InstructionSequenceGraph(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If this node is highlighted, override its formatting.
|
||||||
|
if (filter.Highlight(instruction.get())) {
|
||||||
|
shape = "diamond";
|
||||||
|
color = kDarkRed;
|
||||||
|
}
|
||||||
|
|
||||||
Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
|
Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
|
||||||
InstructionId(instruction.get()).c_str(), label.c_str(),
|
InstructionId(instruction.get()).c_str(), label.c_str(),
|
||||||
shape.c_str(), NodeColorAttributes(color).c_str());
|
shape.c_str(), NodeColorAttributes(color).c_str());
|
||||||
@ -378,6 +417,10 @@ string InstructionSequenceGraph(
|
|||||||
// Create edges from the instruction's operands to the instruction.
|
// Create edges from the instruction's operands to the instruction.
|
||||||
int64 operand_number = 0;
|
int64 operand_number = 0;
|
||||||
for (auto* operand : instruction->operands()) {
|
for (auto* operand : instruction->operands()) {
|
||||||
|
if (!filter.Show(operand)) {
|
||||||
|
++operand_number;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
string src;
|
string src;
|
||||||
if (operand->opcode() == HloOpcode::kParameter) {
|
if (operand->opcode() == HloOpcode::kParameter) {
|
||||||
// If operand is a parameter, then select the proper partition (port) in
|
// If operand is a parameter, then select the proper partition (port) in
|
||||||
@ -405,10 +448,11 @@ string InstructionSequenceGraph(
|
|||||||
StrAppend(&graph_body,
|
StrAppend(&graph_body,
|
||||||
"label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
|
"label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
|
||||||
"color=lightgrey;\n");
|
"color=lightgrey;\n");
|
||||||
StrAppend(&graph_body, InstructionSequenceGraph(
|
StrAppend(&graph_body,
|
||||||
instruction->fused_instructions(),
|
InstructionSequenceGraph(instruction->fused_instructions(),
|
||||||
show_addresses, show_layouts,
|
show_addresses, show_layouts,
|
||||||
intercomputation_edges, hlo_execution_profile),
|
intercomputation_edges,
|
||||||
|
hlo_execution_profile, NodeFilter()),
|
||||||
"}\n");
|
"}\n");
|
||||||
string fusion_edge =
|
string fusion_edge =
|
||||||
StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
|
StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
|
||||||
@ -450,7 +494,8 @@ svg text {
|
|||||||
string ComputationToDotGraph(const HloComputation& computation,
|
string ComputationToDotGraph(const HloComputation& computation,
|
||||||
const string& label, bool show_addresses,
|
const string& label, bool show_addresses,
|
||||||
bool show_layouts,
|
bool show_layouts,
|
||||||
const HloExecutionProfile* hlo_execution_profile) {
|
const HloExecutionProfile* hlo_execution_profile,
|
||||||
|
const NodeFilter& filter) {
|
||||||
string graph_label = StrCat(label, "<br/>", computation.name());
|
string graph_label = StrCat(label, "<br/>", computation.name());
|
||||||
if (hlo_execution_profile != nullptr) {
|
if (hlo_execution_profile != nullptr) {
|
||||||
auto cycles = hlo_execution_profile->total_cycles_executed(computation);
|
auto cycles = hlo_execution_profile->total_cycles_executed(computation);
|
||||||
@ -467,12 +512,31 @@ stylesheet="%s"
|
|||||||
)",
|
)",
|
||||||
graph_label.c_str(), dot_stylesheet);
|
graph_label.c_str(), dot_stylesheet);
|
||||||
|
|
||||||
|
std::unordered_set<const HloComputation*> computations_to_dump;
|
||||||
|
for (const auto& instr : computation.instructions()) {
|
||||||
|
if (!filter.Show(instr.get())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (instr->opcode() == HloOpcode::kFusion) {
|
||||||
|
computations_to_dump.insert(instr->fused_instructions_computation());
|
||||||
|
}
|
||||||
|
for (const HloComputation* computation : instr->called_computations()) {
|
||||||
|
computations_to_dump.insert(computation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Emit embedded computations as subgraph clusters.
|
// Emit embedded computations as subgraph clusters.
|
||||||
std::vector<string> intercomputation_edges;
|
std::vector<string> intercomputation_edges;
|
||||||
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
|
for (const HloComputation* embedded :
|
||||||
|
computation.MakeEmbeddedComputationsList()) {
|
||||||
|
if (!computations_to_dump.count(embedded)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Don't pass our filter down into the subcomputation -- always render the
|
||||||
|
// whole thing.
|
||||||
string graph_body = InstructionSequenceGraph(
|
string graph_body = InstructionSequenceGraph(
|
||||||
embedded->instructions(), show_addresses, show_layouts,
|
embedded->instructions(), show_addresses, show_layouts,
|
||||||
&intercomputation_edges, hlo_execution_profile);
|
&intercomputation_edges, hlo_execution_profile, NodeFilter());
|
||||||
Appendf(&graph,
|
Appendf(&graph,
|
||||||
"subgraph cluster_%s "
|
"subgraph cluster_%s "
|
||||||
"{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
|
"{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
|
||||||
@ -482,7 +546,7 @@ stylesheet="%s"
|
|||||||
StrAppend(&graph,
|
StrAppend(&graph,
|
||||||
InstructionSequenceGraph(computation.instructions(), show_addresses,
|
InstructionSequenceGraph(computation.instructions(), show_addresses,
|
||||||
show_layouts, &intercomputation_edges,
|
show_layouts, &intercomputation_edges,
|
||||||
hlo_execution_profile));
|
hlo_execution_profile, filter));
|
||||||
|
|
||||||
// Edges between computations (subgraph clusters) must be emitted last for the
|
// Edges between computations (subgraph clusters) must be emitted last for the
|
||||||
// graph to be rendered properly for some reason.
|
// graph to be rendered properly for some reason.
|
||||||
@ -555,6 +619,51 @@ class FileGraphRenderer : public GraphRendererInterface {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Gets roughly all instructions whose distance from root is <= radius.
|
||||||
|
std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
|
||||||
|
const HloInstruction& root, int64 radius) {
|
||||||
|
std::unordered_set<const HloInstruction*> ret;
|
||||||
|
|
||||||
|
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
|
||||||
|
worklist.push_back({&root, 0});
|
||||||
|
|
||||||
|
while (!worklist.empty()) {
|
||||||
|
const HloInstruction* instr;
|
||||||
|
int64 depth;
|
||||||
|
std::tie(instr, depth) = worklist.front();
|
||||||
|
worklist.pop_front();
|
||||||
|
|
||||||
|
ret.insert(instr);
|
||||||
|
if (depth == radius) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't traverse into tuples' operands unless the tuple is the root.
|
||||||
|
// Usually a tuple is the bottommost node in the graph, and so its operands
|
||||||
|
// are not interesting to the graph at hand.
|
||||||
|
if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
|
||||||
|
for (const HloInstruction* operand : instr->operands()) {
|
||||||
|
if (ret.find(operand) == ret.end()) {
|
||||||
|
worklist.push_back({operand, depth + 1});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you're looking at node X, it's probably not interesting that node Y
|
||||||
|
// also happens to use the same constant, so we don't traverse into
|
||||||
|
// constants' users.
|
||||||
|
if (instr->opcode() != HloOpcode::kConstant) {
|
||||||
|
for (const HloInstruction* user : instr->users()) {
|
||||||
|
if (ret.find(user) == ret.end()) {
|
||||||
|
worklist.push_back({user, depth + 1});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
|
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -575,9 +684,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
|||||||
graph_url = FileGraphRenderer().RenderGraph(
|
graph_url = FileGraphRenderer().RenderGraph(
|
||||||
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
|
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
|
||||||
} else {
|
} else {
|
||||||
graph = ComputationToDotGraph(
|
graph = ComputationToDotGraph(computation, label,
|
||||||
computation, label, debug_options.xla_hlo_graph_addresses(),
|
debug_options.xla_hlo_graph_addresses(),
|
||||||
debug_options.xla_hlo_graph_layout(), hlo_execution_profile);
|
debug_options.xla_hlo_graph_layout(),
|
||||||
|
hlo_execution_profile, NodeFilter());
|
||||||
graph_url = GetGraphRenderer()->RenderGraph(
|
graph_url = GetGraphRenderer()->RenderGraph(
|
||||||
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
|
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
|
||||||
}
|
}
|
||||||
@ -586,6 +696,33 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
|||||||
return graph_url;
|
return graph_url;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
|
||||||
|
auto debug_options = node.GetModule()->config().debug_options();
|
||||||
|
|
||||||
|
std::unordered_set<const HloInstruction*> neighborhood =
|
||||||
|
GetInstructionsInNeighborhood(node, radius);
|
||||||
|
|
||||||
|
NodeFilter filter([&](const HloInstruction* instr) {
|
||||||
|
if (instr == &node) {
|
||||||
|
return kHighlightNode;
|
||||||
|
}
|
||||||
|
if (neighborhood.find(instr) != neighborhood.end()) {
|
||||||
|
return kNormalNode;
|
||||||
|
}
|
||||||
|
return kHideNode;
|
||||||
|
});
|
||||||
|
|
||||||
|
string label =
|
||||||
|
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
|
||||||
|
string graph = ComputationToDotGraph(
|
||||||
|
*node.parent(), label,
|
||||||
|
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
|
||||||
|
/*show_layouts=*/debug_options.xla_hlo_graph_layout(),
|
||||||
|
/*hlo_execution_profile=*/nullptr, filter);
|
||||||
|
return GetGraphRenderer()->RenderGraph(
|
||||||
|
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
|
||||||
|
}
|
||||||
|
|
||||||
void DumpText(const HloModule& module, const string& label,
|
void DumpText(const HloModule& module, const string& label,
|
||||||
const string& directory_path, bool do_prefix) {
|
const string& directory_path, bool do_prefix) {
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
|
@ -57,6 +57,13 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
|||||||
const DebugOptions& debug_options,
|
const DebugOptions& debug_options,
|
||||||
const HloExecutionProfile* hlo_execution_profile = nullptr);
|
const HloExecutionProfile* hlo_execution_profile = nullptr);
|
||||||
|
|
||||||
|
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
|
||||||
|
//
|
||||||
|
// The number of nodes dumped is controlled by the radius parameter, which
|
||||||
|
// (roughly) corresponds to the max distance a node may be from the primary node
|
||||||
|
// before it's omitted from the graph.
|
||||||
|
string DumpNeighborhoodAround(const HloInstruction& node, int radius);
|
||||||
|
|
||||||
// Dumps the HloModule::ToString() as a file into the provided directory path
|
// Dumps the HloModule::ToString() as a file into the provided directory path
|
||||||
// suffixed with the provided label.
|
// suffixed with the provided label.
|
||||||
//
|
//
|
||||||
|
Loading…
Reference in New Issue
Block a user