diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 07b3369d5c1..3c5113eb4d9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,11 +312,12 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - bool show_addresses, const HloExecutionProfile* profile, - NodeFilter filter) + bool show_addresses, bool show_metadata, + const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label.ToString()), show_addresses_(show_addresses), + show_metadata_(show_metadata), profile_(profile), filter_(std::move(filter)) {} @@ -351,6 +352,7 @@ class HloDotDumper { ColorScheme GetInstructionColor(const HloInstruction* instr); string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); + string GetInstructionNodeMetadata(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedConstants(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -363,6 +365,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const bool show_addresses_; + const bool show_metadata_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -621,6 +624,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { ColorScheme color = GetInstructionColor(instr); string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); + string node_metadata = GetInstructionNodeMetadata(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedConstants(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -638,7 +642,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { // Build the text that will be displayed inside the node. string node_body = node_label; for (const string& s : - {trivial_subcomputation, extra_info, inlined_constants}) { + {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -806,6 +810,16 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { HtmlLikeStringSanitize(instr->name())); } +string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { + if (!show_metadata_ || instr->metadata().op_name().empty()) { + return ""; + } + + return Printf(R"(%s
op_type: %s)", + HtmlLikeStringSanitize(instr->metadata().op_name()), + HtmlLikeStringSanitize(instr->metadata().op_type())); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { string opcode_specific_info = [&]() -> string { switch (instr->opcode()) { @@ -1135,11 +1149,11 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_url = FileGraphRenderer().RenderGraph( graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); } else { - graph = - HloDotDumper(&computation, label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = HloDotDumper( + &computation, label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_metadata=*/false, hlo_execution_profile, NodeFilter()) + .Dump(); graph_url = GetGraphRenderer()->RenderGraph( graph, GraphRendererInterface::DOT_GRAPH, debug_options); } @@ -1148,7 +1162,8 @@ string DumpGraph(const HloComputation& computation, const string& label, return graph_url; } -string DumpNeighborhoodAround(const HloInstruction& node, int radius) { +string DumpNeighborhoodAround(const HloInstruction& node, int radius, + bool show_metadata) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); @@ -1156,7 +1171,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) { string graph = HloDotDumper(node.parent(), label, /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - /*profile=*/nullptr, filter) + show_metadata, /*profile=*/nullptr, filter) .Dump(); return GetGraphRenderer()->RenderGraph( graph, GraphRendererInterface::DOT_GRAPH, debug_options); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 0100d50c050..a17ede7f0a0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -62,7 +62,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // The number of nodes dumped is controlled by the radius parameter, which // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. -string DumpNeighborhoodAround(const HloInstruction& node, int radius); +string DumpNeighborhoodAround(const HloInstruction& node, int radius, + bool show_metadata = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label.