Add a show_fusion_subcomputations command to interactive_graphviz

Hiding fusion subcomputations is useful when we want to only investigate the
connectivity of the computation that contains the fusion instructions.

PiperOrigin-RevId: 313101238
Change-Id: I25e9cfb5857d0cc90e07f45cfa1617fc6d378558
This commit is contained in:
A. Unique TensorFlower 2020-05-25 13:53:10 -07:00 committed by TensorFlower Gardener
parent a7ed5a542e
commit 55c1176fe2
3 changed files with 27 additions and 56 deletions

View File

@ -312,13 +312,12 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options,
HloRenderOptions hlo_render_options,
const DebugOptions& debug_options, bool show_backend_config,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(label),
debug_options_(debug_options),
hlo_render_options_(hlo_render_options),
show_backend_config_(show_backend_config),
profile_(profile),
filter_(std::move(filter)) {}
@ -385,7 +384,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
const HloRenderOptions hlo_render_options_;
const bool show_backend_config_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@ -566,8 +565,7 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
if (subcomp->IsFusionComputation()) {
const HloInstruction* fusion = subcomp->FusionInstruction();
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
!hlo_render_options_.show_fusion_subcomputations) {
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
return false;
}
}
@ -1135,8 +1133,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeBackendConfig(
const HloInstruction* instr) {
if (!hlo_render_options_.show_backend_config ||
instr->raw_backend_config_string().empty()) {
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
return "";
}
@ -1607,14 +1604,14 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
const DebugOptions& debug_options,
RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile,
HloRenderOptions hlo_render_options) {
bool show_backend_config) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return Unavailable("Can't render as URL; no URL renderer was registered.");
}
string rendered_dot =
HloDotDumper(&computation, label, debug_options, hlo_render_options,
HloDotDumper(&computation, label, debug_options, show_backend_config,
hlo_execution_profile, NodeFilter())
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1622,7 +1619,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
HloRenderOptions hlo_render_options,
bool show_backend_config,
const absl::flat_hash_set<const HloInstruction*>& boundary) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
@ -1635,7 +1632,7 @@ StatusOr<string> RenderNeighborhoodAround(
string rendered_dot =
HloDotDumper(node.parent(), label,
node.GetModule()->config().debug_options(),
hlo_render_options, /*profile=*/nullptr,
show_backend_config, /*profile=*/nullptr,
MakeNodeRadiusAroundFilter(&node, radius, boundary))
.Dump();
return WrapDotInFormat(rendered_dot, format);
@ -1644,7 +1641,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
HloRenderOptions hlo_render_options) {
bool show_backend_config) {
tensorflow::mutex_lock lock(url_renderer_mu);
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
return FailedPrecondition(
@ -1666,7 +1663,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
"NODES***<br/><br/>");
}
string rendered_dot =
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
HloDotDumper(from.parent(), label, debug_options, show_backend_config,
/*profile=*/nullptr, filter)
.Dump();
return WrapDotInFormat(rendered_dot, format);

View File

@ -50,14 +50,6 @@ enum class RenderedGraphFormat {
kUrl,
};
struct HloRenderOptions {
// Include the backend config string in the rendered graph.
bool show_backend_config = false;
// Include the fusion subcomputations in the rendered graph.
bool show_fusion_subcomputations = true;
};
// Renders an HLO module as a human-readable visual graph.
//
// Note that this only works well for relatively small graphs (no more than a
@ -69,7 +61,7 @@ StatusOr<string> RenderGraph(
const HloComputation& computation, absl::string_view label,
const DebugOptions& debug_options, RenderedGraphFormat format,
const HloExecutionProfile* hlo_execution_profile = nullptr,
HloRenderOptions hlo_render_options = {});
bool show_backend_config = false);
// Like RenderGraph, but renders only nodes "near" the given node in the graph.
//
@ -81,7 +73,7 @@ StatusOr<string> RenderGraph(
// will be omitted even if they are within the radius.
StatusOr<string> RenderNeighborhoodAround(
const HloInstruction& node, int radius, RenderedGraphFormat format,
HloRenderOptions hlo_render_options = {},
bool show_backend_config = false,
const absl::flat_hash_set<const HloInstruction*>& boundary = {});
// Renders nodes on any of the paths from `from` to `to`. If there are more
@ -90,7 +82,7 @@ StatusOr<string> RenderNeighborhoodAround(
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
const HloInstruction& to, int64 max_nodes,
RenderedGraphFormat format,
HloRenderOptions hlo_render_options = {});
bool show_backend_config = false);
// Registers a function which implements RenderedGraphFormat::kUrl.
//

View File

@ -112,7 +112,8 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
using absl::EqualsIgnoreCase;
HloRenderOptions hlo_render_options;
// A global control for whether backend configuration display is enabled.
bool show_backend_config = true;
HloInstruction* FindInstruction(const HloModule& module, string node_name) {
if (absl::StartsWith(node_name, "%")) {
@ -159,8 +160,6 @@ void DoHelpCommand() {
Renders all nodes in <computation>.
backend_config [on|off]
Controls whether backend operation configuration information is printed.
show_fusion_subcomputations [on|off]
Controls whether fusion subcomputations are shown.
list [name|op_name|op_type] <pattern>
Lists all instructions whose name, metadata op_name, or metadata op_type
contains <pattern> as a substring.
@ -183,32 +182,15 @@ void DoHelpCommand() {
// Turn metadata-printing on or off.
void DoBackendConfigCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
hlo_render_options.show_backend_config = true;
show_backend_config = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
hlo_render_options.show_backend_config = false;
show_backend_config = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
<< std::endl;
}
std::cout << "Backend configuration display "
<< (hlo_render_options.show_backend_config ? "ON" : "OFF")
<< std::endl;
}
// Turn fusion computation display on or off.
void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
if (tokens.size() == 2 && tokens[1] == "on") {
hlo_render_options.show_fusion_subcomputations = true;
} else if (tokens.size() == 2 && tokens[1] == "off") {
hlo_render_options.show_fusion_subcomputations = false;
} else if (tokens.size() != 1) {
std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
"'on' or 'off'.)"
<< std::endl;
}
std::cout << "Fusion subcomputations display "
<< (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
<< std::endl;
<< (show_backend_config ? "ON" : "OFF") << std::endl;
}
// List all computations in the module.
@ -391,7 +373,7 @@ void DoExtractCommand(const HloModule& module,
auto extracted_module = ExtractModule(instr, height);
std::cout << extracted_module->ToString(
HloPrintOptions::ShortParsable().set_print_backend_config(
hlo_render_options.show_backend_config))
show_backend_config))
<< std::endl;
}
@ -535,7 +517,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module,
}
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderAllPathsFromTo(*from, *to, max_nodes, format,
hlo_render_options);
/*show_backend_config=*/show_backend_config);
});
}
@ -600,13 +582,15 @@ void DoPlotCommand(const Options& opts, const HloModule& module,
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderGraph(*comp, /*label=*/"",
comp->parent()->config().debug_options(), format,
/*hlo_execution_profile=*/nullptr, hlo_render_options);
/*hlo_execution_profile=*/nullptr,
/*show_backend_config=*/show_backend_config);
});
} else {
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
return RenderNeighborhoodAround(*instr, graph_width, format,
hlo_render_options,
/*boundary=*/boundary);
return RenderNeighborhoodAround(
*instr, graph_width, format,
/*show_backend_config=*/show_backend_config,
/*boundary=*/boundary);
});
}
}
@ -633,8 +617,6 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
DoHelpCommand();
} else if (tokens[0] == "backend_config") {
DoBackendConfigCommand(tokens);
} else if (tokens[0] == "show_fusion_subcomputations") {
DoShowFusionSubcomputationsCommand(tokens);
} else if (tokens[0] == "list") {
if (tokens.size() > 1 && tokens[1] == "computations") {
DoListComputationsCommand(module, tokens);