Roll forward "Add a show_fusion_subcomputations command to interactive_graphviz" with fix

PiperOrigin-RevId: 313426932
Change-Id: Ia2366ee899d7bd0d69448144d1c18164d5801753
This commit is contained in:
Sanjoy Das 2020-05-27 11:20:16 -07:00 committed by TensorFlower Gardener
parent a5622fee57
commit dc18758c27
3 changed files with 56 additions and 27 deletions

View File

@ -312,12 +312,13 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options, bool show_backend_config,
const DebugOptions& debug_options,
HloRenderOptions hlo_render_options,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(label),
debug_options_(debug_options),
show_backend_config_(show_backend_config),
hlo_render_options_(hlo_render_options),
profile_(profile),
filter_(std::move(filter)) {}
@ -384,7 +385,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
const bool show_backend_config_;
const HloRenderOptions hlo_render_options_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@ -565,7 +566,8 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
if (subcomp->IsFusionComputation()) {
const HloInstruction* fusion = subcomp->FusionInstruction();
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
!hlo_render_options_.show_fusion_subcomputations) {
return false;
}
}
@ -1133,7 +1135,8 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeBackendConfig(
const HloInstruction* instr) {
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
if (!hlo_render_options_.show_backend_config ||
instr->raw_backend_config_string().empty()) {
return "";
}
@ -1604,14 +1607,14 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
const DebugOptions& debug_options,
RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile,
bool show_backend_config) {
HloRenderOptions hlo_render_options) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return Unavailable("Can't render as URL; no URL renderer was registered.");
}
string rendered_dot =
HloDotDumper(&computation, label, debug_options, show_backend_config,
HloDotDumper(&computation, label, debug_options, hlo_render_options,
hlo_execution_profile, NodeFilter())
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1619,7 +1622,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
bool show_backend_config,
HloRenderOptions hlo_render_options,
const absl::flat_hash_set<const HloInstruction*>& boundary) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
@ -1632,7 +1635,7 @@ StatusOr<string> RenderNeighborhoodAround(
string rendered_dot =
HloDotDumper(node.parent(), label,
node.GetModule()->config().debug_options(),
show_backend_config, /*profile=*/nullptr,
hlo_render_options, /*profile=*/nullptr,
MakeNodeRadiusAroundFilter(&node, radius, boundary))
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1641,7 +1644,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
bool show_backend_config) {
HloRenderOptions hlo_render_options) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return FailedPrecondition(
@ -1663,7 +1666,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
"NODES***<br/><br/>");
}
string rendered_dot =
HloDotDumper(from.parent(), label, debug_options, show_backend_config,
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
/*profile=*/nullptr, filter)
.Dump();
return WrapDotInFormat(rendered_dot, format);

View File

@ -50,6 +50,14 @@ enum class RenderedGraphFormat {
kUrl,
};
struct HloRenderOptions {
// Include the backend config string in the rendered graph.
bool show_backend_config = false;
// Include the fusion subcomputations in the rendered graph.
bool show_fusion_subcomputations = true;
};
// Renders an HLO module as a human-readable visual graph.
//
// Note that this only works well for relatively small graphs (no more than a
@ -61,7 +69,7 @@ StatusOr<string> RenderGraph(
const HloComputation& computation, absl::string_view label,
const DebugOptions& debug_options, RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile = nullptr,
bool show_backend_config = false);
HloRenderOptions hlo_render_options = {});
// Like RenderGraph, but renders only nodes "near" the given node in the graph.
//
@ -73,7 +81,7 @@ StatusOr<string> RenderGraph(
// will be omitted even if they are within the radius.
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
bool show_backend_config = false,
HloRenderOptions hlo_render_options = {},
const absl::flat_hash_set<const HloInstruction*>& boundary = {});
// Renders nodes on any of the paths from `from` to `to`. If there are more
@ -82,7 +90,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
bool show_backend_config = false);
HloRenderOptions hlo_render_options = {});
// Registers a function which implements RenderedGraphFormat::kUrl.
//

View File

@ -112,8 +112,7 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
using absl::EqualsIgnoreCase;
// A global control for whether backend configuration display is enabled.
bool show_backend_config = true;
HloRenderOptions hlo_render_options;
HloInstruction* FindInstruction(const HloModule& module, string node_name) {
if (absl::StartsWith(node_name, "%")) {
@ -160,6 +159,8 @@ void DoHelpCommand() {
Renders all nodes in <computation>.
backend_config [on|off]
Controls whether backend operation configuration information is printed.
show_fusion_subcomputations [on|off]
Controls whether fusion subcomputations are shown.
list [name|op_name|op_type] <pattern>
Lists all instructions whose name, metadata op_name, or metadata op_type
contains <pattern> as a substring.
@ -182,15 +183,32 @@ void DoHelpCommand() {
// Turn metadata-printing on or off.
void DoBackendConfigCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
show_backend_config = true;
hlo_render_options.show_backend_config = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
show_backend_config = false;
hlo_render_options.show_backend_config = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
<< std::endl;
}
std::cout << "Backend configuration display "
<< (show_backend_config ? "ON" : "OFF") << std::endl;
<< (hlo_render_options.show_backend_config ? "ON" : "OFF")
<< std::endl;
}
// Turn fusion computation display on or off.
void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
hlo_render_options.show_fusion_subcomputations = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
hlo_render_options.show_fusion_subcomputations = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
"'on' or 'off'.)"
<< std::endl;
}
std::cout << "Fusion subcomputations display "
<< (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
<< std::endl;
}
// List all computations in the module.
@ -373,7 +391,7 @@ void DoExtractCommand(const HloModule& module,
auto extracted_module = ExtractModule(instr, height);
std::cout << extracted_module->ToString(
HloPrintOptions::ShortParsable().set_print_backend_config(
show_backend_config))
hlo_render_options.show_backend_config))
<< std::endl;
}
@ -517,7 +535,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module,
}
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderAllPathsFromTo(*from, *to, max_nodes, format,
/*show_backend_config=*/show_backend_config);
hlo_render_options);
});
}
@ -582,15 +600,13 @@ void DoPlotCommand(const Options& opts, const HloModule& module,
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderGraph(*comp, /*label=*/"",
comp->parent()->config().debug_options(), format,
/*hlo_execution_profile=*/nullptr,
/*show_backend_config=*/show_backend_config);
/*hlo_execution_profile=*/nullptr, hlo_render_options);
});
} else {
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderNeighborhoodAround(
*instr, graph_width, format,
/*show_backend_config=*/show_backend_config,
/*boundary=*/boundary);
return RenderNeighborhoodAround(*instr, graph_width, format,
hlo_render_options,
/*boundary=*/boundary);
});
}
}
@ -617,6 +633,8 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
DoHelpCommand();
} else if (tokens[0] == "backend_config") {
DoBackendConfigCommand(tokens);
} else if (tokens[0] == "show_fusion_subcomputations") {
DoShowFusionSubcomputationsCommand(tokens);
} else if (tokens[0] == "list") {
if (tokens.size() > 1 && tokens[1] == "computations") {
DoListComputationsCommand(module, tokens);