diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 6138496fe5c..a712dc3b33b 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -64,6 +64,32 @@ enum ColorScheme { kYellow, }; +// Used to indicate how we should treat a given HLOInstruction in the graph -- +// should we treat it like normal, hide it, or highlight it? +enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode }; + +// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult. +// It lets callers tell the graph-drawing routines which nodes they want to be +// shown, hidden, or highlighted. +class NodeFilter { + public: + NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {} + + explicit NodeFilter( + std::function filter) + : filter_(std::move(filter)) {} + + bool Show(const HloInstruction* instr) const { + return filter_(instr) != kHideNode; + } + bool Highlight(const HloInstruction* instr) const { + return filter_(instr) == kHighlightNode; + } + + private: + std::function filter_; +}; + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's fill, stroke, and text colors. // @@ -134,7 +160,8 @@ string InstructionSequenceGraph( const std::list>& instructions, bool show_addresses, bool show_layouts, std::vector* intercomputation_edges, - const HloExecutionProfile* hlo_execution_profile) { + const HloExecutionProfile* hlo_execution_profile, + const NodeFilter& filter) { string graph_body; // Create a single "record" node for the parameters. This node is a @@ -157,6 +184,9 @@ string InstructionSequenceGraph( param_node_name = StrCat("parameters_", InstructionId(param_instructions[0])); for (auto& param : param_instructions) { + if (!filter.Show(param)) { + continue; + } string label = StrCat(param->parameter_name(), "\\n", ShapeUtil::HumanString(param->shape())); if (show_addresses) { @@ -178,6 +208,9 @@ string InstructionSequenceGraph( } for (auto& instruction : instructions) { + if (!filter.Show(instruction.get())) { + continue; + } ColorScheme color = kYellow; string shape = "box"; string name = @@ -325,7 +358,7 @@ string InstructionSequenceGraph( case HloOpcode::kReducePrecision: // Make ReducePrecision ops a bit more visible, since typically they // will be inserted as modifications to an existing graph. - color = kDarkRed; + color = kRed; break; } @@ -371,6 +404,12 @@ string InstructionSequenceGraph( } } + // If this node is highlighted, override its formatting. + if (filter.Highlight(instruction.get())) { + shape = "diamond"; + color = kDarkRed; + } + Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", InstructionId(instruction.get()).c_str(), label.c_str(), shape.c_str(), NodeColorAttributes(color).c_str()); @@ -378,6 +417,10 @@ string InstructionSequenceGraph( // Create edges from the instruction's operands to the instruction. int64 operand_number = 0; for (auto* operand : instruction->operands()) { + if (!filter.Show(operand)) { + ++operand_number; + continue; + } string src; if (operand->opcode() == HloOpcode::kParameter) { // If operand is a parameter, then select the proper partition (port) in @@ -405,10 +448,11 @@ string InstructionSequenceGraph( StrAppend(&graph_body, "label=<fused expression>;\nstyle=\"rounded,filled\";\n" "color=lightgrey;\n"); - StrAppend(&graph_body, InstructionSequenceGraph( - instruction->fused_instructions(), - show_addresses, show_layouts, - intercomputation_edges, hlo_execution_profile), + StrAppend(&graph_body, + InstructionSequenceGraph(instruction->fused_instructions(), + show_addresses, show_layouts, + intercomputation_edges, + hlo_execution_profile, NodeFilter()), "}\n"); string fusion_edge = StrCat(InstructionId(instruction->fused_expression_root()), " -> ", @@ -450,7 +494,8 @@ svg text { string ComputationToDotGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, - const HloExecutionProfile* hlo_execution_profile) { + const HloExecutionProfile* hlo_execution_profile, + const NodeFilter& filter) { string graph_label = StrCat(label, "
", computation.name()); if (hlo_execution_profile != nullptr) { auto cycles = hlo_execution_profile->total_cycles_executed(computation); @@ -467,12 +512,31 @@ stylesheet="%s" )", graph_label.c_str(), dot_stylesheet); + std::unordered_set computations_to_dump; + for (const auto& instr : computation.instructions()) { + if (!filter.Show(instr.get())) { + continue; + } + if (instr->opcode() == HloOpcode::kFusion) { + computations_to_dump.insert(instr->fused_instructions_computation()); + } + for (const HloComputation* computation : instr->called_computations()) { + computations_to_dump.insert(computation); + } + } + // Emit embedded computations as subgraph clusters. std::vector intercomputation_edges; - for (auto embedded : computation.MakeEmbeddedComputationsList()) { + for (const HloComputation* embedded : + computation.MakeEmbeddedComputationsList()) { + if (!computations_to_dump.count(embedded)) { + continue; + } + // Don't pass our filter down into the subcomputation -- always render the + // whole thing. string graph_body = InstructionSequenceGraph( embedded->instructions(), show_addresses, show_layouts, - &intercomputation_edges, hlo_execution_profile); + &intercomputation_edges, hlo_execution_profile, NodeFilter()); Appendf(&graph, "subgraph cluster_%s " "{\nstyle=rounded;label=<%s>;labelloc=t;\n%s}\n", @@ -482,7 +546,7 @@ stylesheet="%s" StrAppend(&graph, InstructionSequenceGraph(computation.instructions(), show_addresses, show_layouts, &intercomputation_edges, - hlo_execution_profile)); + hlo_execution_profile, filter)); // Edges between computations (subgraph clusters) must be emitted last for the // graph to be rendered properly for some reason. @@ -555,6 +619,51 @@ class FileGraphRenderer : public GraphRendererInterface { } }; +// Gets roughly all instructions whose distance from root is <= radius. +std::unordered_set GetInstructionsInNeighborhood( + const HloInstruction& root, int64 radius) { + std::unordered_set ret; + + std::deque> worklist; + worklist.push_back({&root, 0}); + + while (!worklist.empty()) { + const HloInstruction* instr; + int64 depth; + std::tie(instr, depth) = worklist.front(); + worklist.pop_front(); + + ret.insert(instr); + if (depth == radius) { + continue; + } + + // Don't traverse into tuples' operands unless the tuple is the root. + // Usually a tuple is the bottommost node in the graph, and so its operands + // are not interesting to the graph at hand. + if (instr == &root || instr->opcode() != HloOpcode::kTuple) { + for (const HloInstruction* operand : instr->operands()) { + if (ret.find(operand) == ret.end()) { + worklist.push_back({operand, depth + 1}); + } + } + } + + // If you're looking at node X, it's probably not interesting that node Y + // also happens to use the same constant, so we don't traverse into + // constants' users. + if (instr->opcode() != HloOpcode::kConstant) { + for (const HloInstruction* user : instr->users()) { + if (ret.find(user) == ret.end()) { + worklist.push_back({user, depth + 1}); + } + } + } + } + + return ret; +} + XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); } // namespace @@ -575,9 +684,10 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_url = FileGraphRenderer().RenderGraph( graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); } else { - graph = ComputationToDotGraph( - computation, label, debug_options.xla_hlo_graph_addresses(), - debug_options.xla_hlo_graph_layout(), hlo_execution_profile); + graph = ComputationToDotGraph(computation, label, + debug_options.xla_hlo_graph_addresses(), + debug_options.xla_hlo_graph_layout(), + hlo_execution_profile, NodeFilter()); graph_url = GetGraphRenderer()->RenderGraph( graph, GraphRendererInterface::DOT_GRAPH, debug_options); } @@ -586,6 +696,33 @@ string DumpGraph(const HloComputation& computation, const string& label, return graph_url; } +string DumpNeighborhoodAround(const HloInstruction& node, int radius) { + auto debug_options = node.GetModule()->config().debug_options(); + + std::unordered_set neighborhood = + GetInstructionsInNeighborhood(node, radius); + + NodeFilter filter([&](const HloInstruction* instr) { + if (instr == &node) { + return kHighlightNode; + } + if (neighborhood.find(instr) != neighborhood.end()) { + return kNormalNode; + } + return kHideNode; + }); + + string label = + StrCat("Neighborhood of ", radius, " nodes around ", node.name()); + string graph = ComputationToDotGraph( + *node.parent(), label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_layouts=*/debug_options.xla_hlo_graph_layout(), + /*hlo_execution_profile=*/nullptr, filter); + return GetGraphRenderer()->RenderGraph( + graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index bc404a7a37f..0100d50c050 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -57,6 +57,13 @@ string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr); +// Like DumpGraph, but renders only nodes "near" the given node in the graph. +// +// The number of nodes dumped is controlled by the radius parameter, which +// (roughly) corresponds to the max distance a node may be from the primary node +// before it's omitted from the graph. +string DumpNeighborhoodAround(const HloInstruction& node, int radius); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. //