[XLA] Optionally add metadata lines to graph neighborhood dumps.

PiperOrigin-RevId: 167911962
This commit is contained in:
A. Unique TensorFlower 2017-09-07 14:28:29 -07:00 committed by TensorFlower Gardener
parent 0575c60ac8
commit 2494aa452b
2 changed files with 27 additions and 11 deletions

View File

@ -312,11 +312,12 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper { class HloDotDumper {
public: public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
bool show_addresses, const HloExecutionProfile* profile, bool show_addresses, bool show_metadata,
NodeFilter filter) const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation), : computation_(computation),
label_(label.ToString()), label_(label.ToString()),
show_addresses_(show_addresses), show_addresses_(show_addresses),
show_metadata_(show_metadata),
profile_(profile), profile_(profile),
filter_(std::move(filter)) {} filter_(std::move(filter)) {}
@ -351,6 +352,7 @@ class HloDotDumper {
ColorScheme GetInstructionColor(const HloInstruction* instr); ColorScheme GetInstructionColor(const HloInstruction* instr);
string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeShape(const HloInstruction* instr);
string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr);
string GetInstructionNodeMetadata(const HloInstruction* instr);
string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr);
string GetInstructionNodeInlinedConstants(const HloInstruction* instr); string GetInstructionNodeInlinedConstants(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr);
@ -363,6 +365,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null const HloComputation* computation_; // never null
const string label_; // overall name for the graph const string label_; // overall name for the graph
const bool show_addresses_; const bool show_addresses_;
const bool show_metadata_;
const HloExecutionProfile* profile_; // may be null const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_; const NodeFilter filter_;
@ -621,6 +624,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
ColorScheme color = GetInstructionColor(instr); ColorScheme color = GetInstructionColor(instr);
string node_shape = GetInstructionNodeShape(instr); string node_shape = GetInstructionNodeShape(instr);
string node_label = GetInstructionNodeLabel(instr); string node_label = GetInstructionNodeLabel(instr);
string node_metadata = GetInstructionNodeMetadata(instr);
string extra_info = GetInstructionNodeExtraInfo(instr); string extra_info = GetInstructionNodeExtraInfo(instr);
string inlined_constants = GetInstructionNodeInlinedConstants(instr); string inlined_constants = GetInstructionNodeInlinedConstants(instr);
string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
@ -638,7 +642,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
// Build the text that will be displayed inside the node. // Build the text that will be displayed inside the node.
string node_body = node_label; string node_body = node_label;
for (const string& s : for (const string& s :
{trivial_subcomputation, extra_info, inlined_constants}) { {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
if (!s.empty()) { if (!s.empty()) {
StrAppend(&node_body, "<br/>", s); StrAppend(&node_body, "<br/>", s);
} }
@ -806,6 +810,16 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
HtmlLikeStringSanitize(instr->name())); HtmlLikeStringSanitize(instr->name()));
} }
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
if (!show_metadata_ || instr->metadata().op_name().empty()) {
return "";
}
return Printf(R"(%s<br/>op_type: %s)",
HtmlLikeStringSanitize(instr->metadata().op_name()),
HtmlLikeStringSanitize(instr->metadata().op_type()));
}
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
string opcode_specific_info = [&]() -> string { string opcode_specific_info = [&]() -> string {
switch (instr->opcode()) { switch (instr->opcode()) {
@ -1135,11 +1149,11 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph( graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else { } else {
graph = graph = HloDotDumper(
HloDotDumper(&computation, label, &computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(), /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
hlo_execution_profile, NodeFilter()) /*show_metadata=*/false, hlo_execution_profile, NodeFilter())
.Dump(); .Dump();
graph_url = GetGraphRenderer()->RenderGraph( graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options); graph, GraphRendererInterface::DOT_GRAPH, debug_options);
} }
@ -1148,7 +1162,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
return graph_url; return graph_url;
} }
string DumpNeighborhoodAround(const HloInstruction& node, int radius) { string DumpNeighborhoodAround(const HloInstruction& node, int radius,
bool show_metadata) {
auto debug_options = node.GetModule()->config().debug_options(); auto debug_options = node.GetModule()->config().debug_options();
string label = string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name()); StrCat("Neighborhood of ", radius, " nodes around ", node.name());
@ -1156,7 +1171,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
string graph = string graph =
HloDotDumper(node.parent(), label, HloDotDumper(node.parent(), label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(), /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*profile=*/nullptr, filter) show_metadata, /*profile=*/nullptr, filter)
.Dump(); .Dump();
return GetGraphRenderer()->RenderGraph( return GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options); graph, GraphRendererInterface::DOT_GRAPH, debug_options);

View File

@ -62,7 +62,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
// The number of nodes dumped is controlled by the radius parameter, which // The number of nodes dumped is controlled by the radius parameter, which
// (roughly) corresponds to the max distance a node may be from the primary node // (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph. // before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius); string DumpNeighborhoodAround(const HloInstruction& node, int radius,
bool show_metadata = false);
// Dumps the HloModule::ToString() as a file into the provided directory path // Dumps the HloModule::ToString() as a file into the provided directory path
// suffixed with the provided label. // suffixed with the provided label.