Add a show_fusion_subcomputations command to interactive_graphviz
Hiding fusion subcomputations is useful when we want to only investigate the connectivity of the computation that contains the fusion instructions. PiperOrigin-RevId: 313101238 Change-Id: I25e9cfb5857d0cc90e07f45cfa1617fc6d378558
This commit is contained in:
parent
a7ed5a542e
commit
55c1176fe2
|
@ -312,13 +312,12 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
|||
class HloDotDumper {
|
||||
public:
|
||||
HloDotDumper(const HloComputation* computation, absl::string_view label,
|
||||
const DebugOptions& debug_options,
|
||||
HloRenderOptions hlo_render_options,
|
||||
const DebugOptions& debug_options, bool show_backend_config,
|
||||
const HloExecutionProfile* profile, NodeFilter filter)
|
||||
: computation_(computation),
|
||||
label_(label),
|
||||
debug_options_(debug_options),
|
||||
hlo_render_options_(hlo_render_options),
|
||||
show_backend_config_(show_backend_config),
|
||||
profile_(profile),
|
||||
filter_(std::move(filter)) {}
|
||||
|
||||
|
@ -385,7 +384,7 @@ class HloDotDumper {
|
|||
const HloComputation* computation_; // never null
|
||||
const string label_; // overall name for the graph
|
||||
const DebugOptions& debug_options_;
|
||||
const HloRenderOptions hlo_render_options_;
|
||||
const bool show_backend_config_;
|
||||
const HloExecutionProfile* profile_; // may be null
|
||||
const NodeFilter filter_;
|
||||
|
||||
|
@ -566,8 +565,7 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
|
|||
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
|
||||
if (subcomp->IsFusionComputation()) {
|
||||
const HloInstruction* fusion = subcomp->FusionInstruction();
|
||||
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
|
||||
!hlo_render_options_.show_fusion_subcomputations) {
|
||||
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -1135,8 +1133,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
|
|||
|
||||
string HloDotDumper::GetInstructionNodeBackendConfig(
|
||||
const HloInstruction* instr) {
|
||||
if (!hlo_render_options_.show_backend_config ||
|
||||
instr->raw_backend_config_string().empty()) {
|
||||
if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
|
@ -1607,14 +1604,14 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
|
|||
const DebugOptions& debug_options,
|
||||
RenderedGraphFormat format,
|
||||
const HloExecutionProfile* hlo_execution_profile,
|
||||
HloRenderOptions hlo_render_options) {
|
||||
bool show_backend_config) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
return Unavailable("Can't render as URL; no URL renderer was registered.");
|
||||
}
|
||||
|
||||
string rendered_dot =
|
||||
HloDotDumper(&computation, label, debug_options, hlo_render_options,
|
||||
HloDotDumper(&computation, label, debug_options, show_backend_config,
|
||||
hlo_execution_profile, NodeFilter())
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
|
@ -1622,7 +1619,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
|
|||
|
||||
StatusOr<string> RenderNeighborhoodAround(
|
||||
const HloInstruction& node, int radius, RenderedGraphFormat format,
|
||||
HloRenderOptions hlo_render_options,
|
||||
bool show_backend_config,
|
||||
const absl::flat_hash_set<const HloInstruction*>& boundary) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
|
@ -1635,7 +1632,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
|||
string rendered_dot =
|
||||
HloDotDumper(node.parent(), label,
|
||||
node.GetModule()->config().debug_options(),
|
||||
hlo_render_options, /*profile=*/nullptr,
|
||||
show_backend_config, /*profile=*/nullptr,
|
||||
MakeNodeRadiusAroundFilter(&node, radius, boundary))
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
|
@ -1644,7 +1641,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
|||
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
const HloInstruction& to, int64 max_nodes,
|
||||
RenderedGraphFormat format,
|
||||
HloRenderOptions hlo_render_options) {
|
||||
bool show_backend_config) {
|
||||
tensorflow::mutex_lock lock(url_renderer_mu);
|
||||
if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
|
||||
return FailedPrecondition(
|
||||
|
@ -1666,7 +1663,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
|||
"NODES***<br/><br/>");
|
||||
}
|
||||
string rendered_dot =
|
||||
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
|
||||
HloDotDumper(from.parent(), label, debug_options, show_backend_config,
|
||||
/*profile=*/nullptr, filter)
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
|
|
|
@ -50,14 +50,6 @@ enum class RenderedGraphFormat {
|
|||
kUrl,
|
||||
};
|
||||
|
||||
struct HloRenderOptions {
|
||||
// Include the backend config string in the rendered graph.
|
||||
bool show_backend_config = false;
|
||||
|
||||
// Include the fusion subcomputations in the rendered graph.
|
||||
bool show_fusion_subcomputations = true;
|
||||
};
|
||||
|
||||
// Renders an HLO module as a human-readable visual graph.
|
||||
//
|
||||
// Note that this only works well for relatively small graphs (no more than a
|
||||
|
@ -69,7 +61,7 @@ StatusOr<string> RenderGraph(
|
|||
const HloComputation& computation, absl::string_view label,
|
||||
const DebugOptions& debug_options, RenderedGraphFormat format,
|
||||
const HloExecutionProfile* hlo_execution_profile = nullptr,
|
||||
HloRenderOptions hlo_render_options = {});
|
||||
bool show_backend_config = false);
|
||||
|
||||
// Like RenderGraph, but renders only nodes "near" the given node in the graph.
|
||||
//
|
||||
|
@ -81,7 +73,7 @@ StatusOr<string> RenderGraph(
|
|||
// will be omitted even if they are within the radius.
|
||||
StatusOr<string> RenderNeighborhoodAround(
|
||||
const HloInstruction& node, int radius, RenderedGraphFormat format,
|
||||
HloRenderOptions hlo_render_options = {},
|
||||
bool show_backend_config = false,
|
||||
const absl::flat_hash_set<const HloInstruction*>& boundary = {});
|
||||
|
||||
// Renders nodes on any of the paths from `from` to `to`. If there are more
|
||||
|
@ -90,7 +82,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
|||
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
const HloInstruction& to, int64 max_nodes,
|
||||
RenderedGraphFormat format,
|
||||
HloRenderOptions hlo_render_options = {});
|
||||
bool show_backend_config = false);
|
||||
|
||||
// Registers a function which implements RenderedGraphFormat::kUrl.
|
||||
//
|
||||
|
|
|
@ -112,7 +112,8 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
|
|||
|
||||
using absl::EqualsIgnoreCase;
|
||||
|
||||
HloRenderOptions hlo_render_options;
|
||||
// A global control for whether backend configuration display is enabled.
|
||||
bool show_backend_config = true;
|
||||
|
||||
HloInstruction* FindInstruction(const HloModule& module, string node_name) {
|
||||
if (absl::StartsWith(node_name, "%")) {
|
||||
|
@ -159,8 +160,6 @@ void DoHelpCommand() {
|
|||
Renders all nodes in <computation>.
|
||||
backend_config [on|off]
|
||||
Controls whether backend operation configuration information is printed.
|
||||
show_fusion_subcomputations [on|off]
|
||||
Controls whether fusion subcomputations are shown.
|
||||
list [name|op_name|op_type] <pattern>
|
||||
Lists all instructions whose name, metadata op_name, or metadata op_type
|
||||
contains <pattern> as a substring.
|
||||
|
@ -183,32 +182,15 @@ void DoHelpCommand() {
|
|||
// Turn metadata-printing on or off.
|
||||
void DoBackendConfigCommand(const std::vector<string>& tokens) {
|
||||
if (tokens.size() == 2 && tokens[1] == "on") {
|
||||
hlo_render_options.show_backend_config = true;
|
||||
show_backend_config = true;
|
||||
} else if (tokens.size() == 2 && tokens[1] == "off") {
|
||||
hlo_render_options.show_backend_config = false;
|
||||
show_backend_config = false;
|
||||
} else if (tokens.size() != 1) {
|
||||
std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "Backend configuration display "
|
||||
<< (hlo_render_options.show_backend_config ? "ON" : "OFF")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Turn fusion computation display on or off.
|
||||
void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
|
||||
if (tokens.size() == 2 && tokens[1] == "on") {
|
||||
hlo_render_options.show_fusion_subcomputations = true;
|
||||
} else if (tokens.size() == 2 && tokens[1] == "off") {
|
||||
hlo_render_options.show_fusion_subcomputations = false;
|
||||
} else if (tokens.size() != 1) {
|
||||
std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
|
||||
"'on' or 'off'.)"
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "Fusion subcomputations display "
|
||||
<< (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
|
||||
<< std::endl;
|
||||
<< (show_backend_config ? "ON" : "OFF") << std::endl;
|
||||
}
|
||||
|
||||
// List all computations in the module.
|
||||
|
@ -391,7 +373,7 @@ void DoExtractCommand(const HloModule& module,
|
|||
auto extracted_module = ExtractModule(instr, height);
|
||||
std::cout << extracted_module->ToString(
|
||||
HloPrintOptions::ShortParsable().set_print_backend_config(
|
||||
hlo_render_options.show_backend_config))
|
||||
show_backend_config))
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
|
@ -535,7 +517,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module,
|
|||
}
|
||||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderAllPathsFromTo(*from, *to, max_nodes, format,
|
||||
hlo_render_options);
|
||||
/*show_backend_config=*/show_backend_config);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -600,13 +582,15 @@ void DoPlotCommand(const Options& opts, const HloModule& module,
|
|||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderGraph(*comp, /*label=*/"",
|
||||
comp->parent()->config().debug_options(), format,
|
||||
/*hlo_execution_profile=*/nullptr, hlo_render_options);
|
||||
/*hlo_execution_profile=*/nullptr,
|
||||
/*show_backend_config=*/show_backend_config);
|
||||
});
|
||||
} else {
|
||||
RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
|
||||
return RenderNeighborhoodAround(*instr, graph_width, format,
|
||||
hlo_render_options,
|
||||
/*boundary=*/boundary);
|
||||
return RenderNeighborhoodAround(
|
||||
*instr, graph_width, format,
|
||||
/*show_backend_config=*/show_backend_config,
|
||||
/*boundary=*/boundary);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -633,8 +617,6 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
|
|||
DoHelpCommand();
|
||||
} else if (tokens[0] == "backend_config") {
|
||||
DoBackendConfigCommand(tokens);
|
||||
} else if (tokens[0] == "show_fusion_subcomputations") {
|
||||
DoShowFusionSubcomputationsCommand(tokens);
|
||||
} else if (tokens[0] == "list") {
|
||||
if (tokens.size() > 1 && tokens[1] == "computations") {
|
||||
DoListComputationsCommand(module, tokens);
|
||||
|
|
Loading…
Reference in New Issue