[XLA] Optionally add metadata lines to graph neighborhood dumps.

PiperOrigin-RevId: 167911962
This commit is contained in:
A. Unique TensorFlower 2017-09-07 14:28:29 -07:00 committed by TensorFlower Gardener
parent 0575c60ac8
commit 2494aa452b
2 changed files with 27 additions and 11 deletions

View File

@ -312,11 +312,12 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
bool show_addresses, const HloExecutionProfile* profile,
NodeFilter filter)
bool show_addresses, bool show_metadata,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(label.ToString()),
show_addresses_(show_addresses),
show_metadata_(show_metadata),
profile_(profile),
filter_(std::move(filter)) {}
@ -351,6 +352,7 @@ class HloDotDumper {
ColorScheme GetInstructionColor(const HloInstruction* instr);
string GetInstructionNodeShape(const HloInstruction* instr);
string GetInstructionNodeLabel(const HloInstruction* instr);
string GetInstructionNodeMetadata(const HloInstruction* instr);
string GetInstructionNodeExtraInfo(const HloInstruction* instr);
string GetInstructionNodeInlinedConstants(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr);
@ -363,6 +365,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
const bool show_addresses_;
const bool show_metadata_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@ -621,6 +624,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
ColorScheme color = GetInstructionColor(instr);
string node_shape = GetInstructionNodeShape(instr);
string node_label = GetInstructionNodeLabel(instr);
string node_metadata = GetInstructionNodeMetadata(instr);
string extra_info = GetInstructionNodeExtraInfo(instr);
string inlined_constants = GetInstructionNodeInlinedConstants(instr);
string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
@ -638,7 +642,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
// Build the text that will be displayed inside the node.
string node_body = node_label;
for (const string& s :
{trivial_subcomputation, extra_info, inlined_constants}) {
{trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
if (!s.empty()) {
StrAppend(&node_body, "<br/>", s);
}
@ -806,6 +810,16 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
HtmlLikeStringSanitize(instr->name()));
}
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
if (!show_metadata_ || instr->metadata().op_name().empty()) {
return "";
}
return Printf(R"(%s<br/>op_type: %s)",
HtmlLikeStringSanitize(instr->metadata().op_name()),
HtmlLikeStringSanitize(instr->metadata().op_type()));
}
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
string opcode_specific_info = [&]() -> string {
switch (instr->opcode()) {
@ -1135,11 +1149,11 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else {
graph =
HloDotDumper(&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
hlo_execution_profile, NodeFilter())
.Dump();
graph = HloDotDumper(
&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*show_metadata=*/false, hlo_execution_profile, NodeFilter())
.Dump();
graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
@ -1148,7 +1162,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
return graph_url;
}
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
bool show_metadata) {
auto debug_options = node.GetModule()->config().debug_options();
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
@ -1156,7 +1171,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
string graph =
HloDotDumper(node.parent(), label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*profile=*/nullptr, filter)
show_metadata, /*profile=*/nullptr, filter)
.Dump();
return GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);

View File

@ -62,7 +62,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
// The number of nodes dumped is controlled by the radius parameter, which
// (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius);
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
bool show_metadata = false);
// Dumps the HloModule::ToString() as a file into the provided directory path
// suffixed with the provided label.