Add hlo_graph_dumper::GetInstructionsInNeighborhood, which lets you
graph the nodes that are "near" a particular node. PiperOrigin-RevId: 162692461
This commit is contained in:
parent
9cf6fe1ad3
commit
386f4aef0d
@ -64,6 +64,32 @@ enum ColorScheme {
|
||||
kYellow,
|
||||
};
|
||||
|
||||
// Used to indicate how we should treat a given HLOInstruction in the graph --
|
||||
// should we treat it like normal, hide it, or highlight it?
|
||||
enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
|
||||
|
||||
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
|
||||
// It lets callers tell the graph-drawing routines which nodes they want to be
|
||||
// shown, hidden, or highlighted.
|
||||
class NodeFilter {
|
||||
public:
|
||||
NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
|
||||
|
||||
explicit NodeFilter(
|
||||
std::function<NodeFilterResult(const HloInstruction* instr)> filter)
|
||||
: filter_(std::move(filter)) {}
|
||||
|
||||
bool Show(const HloInstruction* instr) const {
|
||||
return filter_(instr) != kHideNode;
|
||||
}
|
||||
bool Highlight(const HloInstruction* instr) const {
|
||||
return filter_(instr) == kHighlightNode;
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
|
||||
};
|
||||
|
||||
// Given a ColorScheme, returns an attribute string for a node of that color.
|
||||
// Sets the node's fill, stroke, and text colors.
|
||||
//
|
||||
@ -134,7 +160,8 @@ string InstructionSequenceGraph(
|
||||
const std::list<std::unique_ptr<HloInstruction>>& instructions,
|
||||
bool show_addresses, bool show_layouts,
|
||||
std::vector<string>* intercomputation_edges,
|
||||
const HloExecutionProfile* hlo_execution_profile) {
|
||||
const HloExecutionProfile* hlo_execution_profile,
|
||||
const NodeFilter& filter) {
|
||||
string graph_body;
|
||||
|
||||
// Create a single "record" node for the parameters. This node is a
|
||||
@ -157,6 +184,9 @@ string InstructionSequenceGraph(
|
||||
param_node_name =
|
||||
StrCat("parameters_", InstructionId(param_instructions[0]));
|
||||
for (auto& param : param_instructions) {
|
||||
if (!filter.Show(param)) {
|
||||
continue;
|
||||
}
|
||||
string label = StrCat(param->parameter_name(), "\\n",
|
||||
ShapeUtil::HumanString(param->shape()));
|
||||
if (show_addresses) {
|
||||
@ -178,6 +208,9 @@ string InstructionSequenceGraph(
|
||||
}
|
||||
|
||||
for (auto& instruction : instructions) {
|
||||
if (!filter.Show(instruction.get())) {
|
||||
continue;
|
||||
}
|
||||
ColorScheme color = kYellow;
|
||||
string shape = "box";
|
||||
string name =
|
||||
@ -325,7 +358,7 @@ string InstructionSequenceGraph(
|
||||
case HloOpcode::kReducePrecision:
|
||||
// Make ReducePrecision ops a bit more visible, since typically they
|
||||
// will be inserted as modifications to an existing graph.
|
||||
color = kDarkRed;
|
||||
color = kRed;
|
||||
break;
|
||||
}
|
||||
|
||||
@ -371,6 +404,12 @@ string InstructionSequenceGraph(
|
||||
}
|
||||
}
|
||||
|
||||
// If this node is highlighted, override its formatting.
|
||||
if (filter.Highlight(instruction.get())) {
|
||||
shape = "diamond";
|
||||
color = kDarkRed;
|
||||
}
|
||||
|
||||
Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
|
||||
InstructionId(instruction.get()).c_str(), label.c_str(),
|
||||
shape.c_str(), NodeColorAttributes(color).c_str());
|
||||
@ -378,6 +417,10 @@ string InstructionSequenceGraph(
|
||||
// Create edges from the instruction's operands to the instruction.
|
||||
int64 operand_number = 0;
|
||||
for (auto* operand : instruction->operands()) {
|
||||
if (!filter.Show(operand)) {
|
||||
++operand_number;
|
||||
continue;
|
||||
}
|
||||
string src;
|
||||
if (operand->opcode() == HloOpcode::kParameter) {
|
||||
// If operand is a parameter, then select the proper partition (port) in
|
||||
@ -405,10 +448,11 @@ string InstructionSequenceGraph(
|
||||
StrAppend(&graph_body,
|
||||
"label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
|
||||
"color=lightgrey;\n");
|
||||
StrAppend(&graph_body, InstructionSequenceGraph(
|
||||
instruction->fused_instructions(),
|
||||
show_addresses, show_layouts,
|
||||
intercomputation_edges, hlo_execution_profile),
|
||||
StrAppend(&graph_body,
|
||||
InstructionSequenceGraph(instruction->fused_instructions(),
|
||||
show_addresses, show_layouts,
|
||||
intercomputation_edges,
|
||||
hlo_execution_profile, NodeFilter()),
|
||||
"}\n");
|
||||
string fusion_edge =
|
||||
StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
|
||||
@ -450,7 +494,8 @@ svg text {
|
||||
string ComputationToDotGraph(const HloComputation& computation,
|
||||
const string& label, bool show_addresses,
|
||||
bool show_layouts,
|
||||
const HloExecutionProfile* hlo_execution_profile) {
|
||||
const HloExecutionProfile* hlo_execution_profile,
|
||||
const NodeFilter& filter) {
|
||||
string graph_label = StrCat(label, "<br/>", computation.name());
|
||||
if (hlo_execution_profile != nullptr) {
|
||||
auto cycles = hlo_execution_profile->total_cycles_executed(computation);
|
||||
@ -467,12 +512,31 @@ stylesheet="%s"
|
||||
)",
|
||||
graph_label.c_str(), dot_stylesheet);
|
||||
|
||||
std::unordered_set<const HloComputation*> computations_to_dump;
|
||||
for (const auto& instr : computation.instructions()) {
|
||||
if (!filter.Show(instr.get())) {
|
||||
continue;
|
||||
}
|
||||
if (instr->opcode() == HloOpcode::kFusion) {
|
||||
computations_to_dump.insert(instr->fused_instructions_computation());
|
||||
}
|
||||
for (const HloComputation* computation : instr->called_computations()) {
|
||||
computations_to_dump.insert(computation);
|
||||
}
|
||||
}
|
||||
|
||||
// Emit embedded computations as subgraph clusters.
|
||||
std::vector<string> intercomputation_edges;
|
||||
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
|
||||
for (const HloComputation* embedded :
|
||||
computation.MakeEmbeddedComputationsList()) {
|
||||
if (!computations_to_dump.count(embedded)) {
|
||||
continue;
|
||||
}
|
||||
// Don't pass our filter down into the subcomputation -- always render the
|
||||
// whole thing.
|
||||
string graph_body = InstructionSequenceGraph(
|
||||
embedded->instructions(), show_addresses, show_layouts,
|
||||
&intercomputation_edges, hlo_execution_profile);
|
||||
&intercomputation_edges, hlo_execution_profile, NodeFilter());
|
||||
Appendf(&graph,
|
||||
"subgraph cluster_%s "
|
||||
"{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
|
||||
@ -482,7 +546,7 @@ stylesheet="%s"
|
||||
StrAppend(&graph,
|
||||
InstructionSequenceGraph(computation.instructions(), show_addresses,
|
||||
show_layouts, &intercomputation_edges,
|
||||
hlo_execution_profile));
|
||||
hlo_execution_profile, filter));
|
||||
|
||||
// Edges between computations (subgraph clusters) must be emitted last for the
|
||||
// graph to be rendered properly for some reason.
|
||||
@ -555,6 +619,51 @@ class FileGraphRenderer : public GraphRendererInterface {
|
||||
}
|
||||
};
|
||||
|
||||
// Gets roughly all instructions whose distance from root is <= radius.
|
||||
std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
|
||||
const HloInstruction& root, int64 radius) {
|
||||
std::unordered_set<const HloInstruction*> ret;
|
||||
|
||||
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
|
||||
worklist.push_back({&root, 0});
|
||||
|
||||
while (!worklist.empty()) {
|
||||
const HloInstruction* instr;
|
||||
int64 depth;
|
||||
std::tie(instr, depth) = worklist.front();
|
||||
worklist.pop_front();
|
||||
|
||||
ret.insert(instr);
|
||||
if (depth == radius) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't traverse into tuples' operands unless the tuple is the root.
|
||||
// Usually a tuple is the bottommost node in the graph, and so its operands
|
||||
// are not interesting to the graph at hand.
|
||||
if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
|
||||
for (const HloInstruction* operand : instr->operands()) {
|
||||
if (ret.find(operand) == ret.end()) {
|
||||
worklist.push_back({operand, depth + 1});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If you're looking at node X, it's probably not interesting that node Y
|
||||
// also happens to use the same constant, so we don't traverse into
|
||||
// constants' users.
|
||||
if (instr->opcode() != HloOpcode::kConstant) {
|
||||
for (const HloInstruction* user : instr->users()) {
|
||||
if (ret.find(user) == ret.end()) {
|
||||
worklist.push_back({user, depth + 1});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
|
||||
|
||||
} // namespace
|
||||
@ -575,9 +684,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
||||
graph_url = FileGraphRenderer().RenderGraph(
|
||||
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
|
||||
} else {
|
||||
graph = ComputationToDotGraph(
|
||||
computation, label, debug_options.xla_hlo_graph_addresses(),
|
||||
debug_options.xla_hlo_graph_layout(), hlo_execution_profile);
|
||||
graph = ComputationToDotGraph(computation, label,
|
||||
debug_options.xla_hlo_graph_addresses(),
|
||||
debug_options.xla_hlo_graph_layout(),
|
||||
hlo_execution_profile, NodeFilter());
|
||||
graph_url = GetGraphRenderer()->RenderGraph(
|
||||
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
|
||||
}
|
||||
@ -586,6 +696,33 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
||||
return graph_url;
|
||||
}
|
||||
|
||||
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
|
||||
auto debug_options = node.GetModule()->config().debug_options();
|
||||
|
||||
std::unordered_set<const HloInstruction*> neighborhood =
|
||||
GetInstructionsInNeighborhood(node, radius);
|
||||
|
||||
NodeFilter filter([&](const HloInstruction* instr) {
|
||||
if (instr == &node) {
|
||||
return kHighlightNode;
|
||||
}
|
||||
if (neighborhood.find(instr) != neighborhood.end()) {
|
||||
return kNormalNode;
|
||||
}
|
||||
return kHideNode;
|
||||
});
|
||||
|
||||
string label =
|
||||
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
|
||||
string graph = ComputationToDotGraph(
|
||||
*node.parent(), label,
|
||||
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
|
||||
/*show_layouts=*/debug_options.xla_hlo_graph_layout(),
|
||||
/*hlo_execution_profile=*/nullptr, filter);
|
||||
return GetGraphRenderer()->RenderGraph(
|
||||
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
|
||||
}
|
||||
|
||||
void DumpText(const HloModule& module, const string& label,
|
||||
const string& directory_path, bool do_prefix) {
|
||||
Env* env = Env::Default();
|
||||
|
@ -57,6 +57,13 @@ string DumpGraph(const HloComputation& computation, const string& label,
|
||||
const DebugOptions& debug_options,
|
||||
const HloExecutionProfile* hlo_execution_profile = nullptr);
|
||||
|
||||
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
|
||||
//
|
||||
// The number of nodes dumped is controlled by the radius parameter, which
|
||||
// (roughly) corresponds to the max distance a node may be from the primary node
|
||||
// before it's omitted from the graph.
|
||||
string DumpNeighborhoodAround(const HloInstruction& node, int radius);
|
||||
|
||||
// Dumps the HloModule::ToString() as a file into the provided directory path
|
||||
// suffixed with the provided label.
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user