[XLA] Implement the fusion progress visualizer, which dumps out the HTML+JS page visualizing the fusion decisions for XLA:GPU

Requires the graphviz URL renderer to be registered.

PiperOrigin-RevId: 353765943
Change-Id: I1994c2f8abaf15b3da6e914b351c26993f5765f5
This commit is contained in:
George Karpenkov 2021-01-25 17:27:22 -08:00 committed by TensorFlower Gardener
parent 588f5a7f60
commit d92917b6d6
11 changed files with 212 additions and 12 deletions

View File

@ -40,6 +40,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_asm_extra_flags("");
opts.set_xla_eliminate_hlo_implicit_broadcast(true);
opts.set_xla_dump_hlo_as_html(false);
opts.set_xla_dump_fusion_visualization(false);
opts.set_xla_dump_include_timestamp(true);
opts.set_xla_dump_max_hlo_modules(-1);
opts.set_xla_dump_module_metadata(false);
@ -483,6 +484,15 @@ static void AllocateFlags() {
"directory specified by --xla_dump_to). This is not implemented by "
"default; you need to add a plugin which calls "
"RegisterGraphToURLRenderer()."));
flag_objects->push_back(tensorflow::Flag(
"xla_dump_fusion_visualization",
bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization),
flag_values->xla_dump_fusion_visualization(),
"Tries to generate HLO fusion visualization as an HTML page to the "
"directory specified by --xla_dump_to). This is not implemented by "
"default; you need to add a plugin which calls "
"RegisterGraphToURLRenderer(). Generates a file per computation. "
"Currently only implemented for the GPU backend."));
flag_objects->push_back(tensorflow::Flag(
"xla_dump_hlo_snapshots",
bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),

View File

@ -1655,6 +1655,7 @@ cc_library(
deps = [
":fusion_queue",
":hlo",
":hlo_graph_dumper",
":hlo_pass",
":hlo_reachability",
":pattern_matcher",

View File

@ -43,6 +43,7 @@ struct CanonicalDebugOptions {
dump_as_dot(opts.xla_dump_hlo_as_dot()),
dump_as_html(opts.xla_dump_hlo_as_html()),
dump_as_url(opts.xla_dump_hlo_as_url()),
dump_fusion_visualization(opts.xla_dump_fusion_visualization()),
dump_snapshots(opts.xla_dump_hlo_snapshots()),
dump_include_timestamp(opts.xla_dump_include_timestamp()),
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()),
@ -135,6 +136,7 @@ struct CanonicalDebugOptions {
bool dump_as_dot;
bool dump_as_html;
bool dump_as_url;
bool dump_fusion_visualization;
bool dump_snapshots;
bool dump_include_timestamp;
int64 dump_max_hlo_modules;
@ -268,6 +270,24 @@ std::vector<std::string> DumpHloModuleImpl(const HloModule& module,
render_graph(RenderedGraphFormat::kHtml), opts));
}
if (opts.dump_fusion_visualization) {
for (const HloComputation* computation :
module.MakeNonfusionComputations()) {
StatusOr<string> rendered_graph = RenderGraph(
*computation,
/*label=*/absl::StrCat(filename, "_", computation->name()),
module.config().debug_options(),
RenderedGraphFormat::kFusionVisualization, profile);
file_paths.push_back(DumpToFileInDirImpl(
StrFormat("%s_%s_fusion_visualization.html", filename,
computation->name()),
rendered_graph.ok() ? *rendered_graph
: StrFormat("Error rendering graph: %s",
rendered_graph.status().ToString()),
opts));
}
}
// Special case for rendering graphs as URLs. We'll dump them to a file
// because why not, but we always log them to stdout as well.
if (opts.dump_as_url) {

View File

@ -1113,6 +1113,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
@ -1209,6 +1210,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/core:lib",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@ -314,7 +315,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
<< " }";
// Remove 'fusion' instruction.
CHECK_EQ(0, fusion->user_count()) << fusion->ToString();
return computation_->RemoveInstruction(fusion);
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(fusion));
if (computation_->parent()
->config()
.debug_options()
.xla_dump_fusion_visualization()) {
TF_RETURN_IF_ERROR(RegisterFusionState(*computation_, "fusion merger"));
}
return Status::OK();
}
StatusOr<bool> FusionMerger::Run(HloModule* module) {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@ -242,12 +243,23 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
return changed;
}
bool GpuMultiOutputFusion::DoMultiOutputFusion() {
StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
bool changed = false;
RecomputeReachability();
std::vector<HloInstruction*> defs_before_uses =
computation_->MakeInstructionPostOrder();
auto dump_fusion_state = [&] {
if (computation_->parent()
->config()
.debug_options()
.xla_dump_fusion_visualization()) {
TF_RETURN_IF_ERROR(
RegisterFusionState(*computation_, "GpuMultiOutputFusion"));
}
return Status::OK();
};
while (!defs_before_uses.empty()) {
// Traverse the HLO in uses-before-defs order by removing instruction from
// the back of the vector.
@ -290,6 +302,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() {
CHECK_EQ(0, producer->user_count());
TF_CHECK_OK(computation_->RemoveInstruction(producer));
}
TF_RETURN_IF_ERROR(dump_fusion_state());
RecomputeReachability();
continue;
}
@ -309,6 +323,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() {
CHECK_EQ(0, producer->user_count());
TF_CHECK_OK(computation_->RemoveInstruction(producer));
}
TF_RETURN_IF_ERROR(dump_fusion_state());
RecomputeReachability();
}
return changed;
@ -318,7 +334,8 @@ StatusOr<bool> GpuMultiOutputFusion::Run(HloModule* module) {
bool changed = false;
for (auto* computation : module->MakeNonfusionComputations()) {
computation_ = computation;
if (DoMultiOutputFusion()) {
TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
if (fusion_changed) {
changed = true;
}
}

View File

@ -98,7 +98,7 @@ class GpuMultiOutputFusion : public HloModulePass {
private:
bool FuseSiblings(HloInstruction* parent);
bool DoMultiOutputFusion();
StatusOr<bool> DoMultiOutputFusion();
// Recompute reachability for the current computation.
void RecomputeReachability();

View File

@ -1577,13 +1577,115 @@ tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
std::function<StatusOr<string>(absl::string_view)>* url_renderer
TF_GUARDED_BY(url_renderer_mu) = nullptr;
// Precondition: url_renderer != nullptr.
// Storage for fusion visualization: (module_id, computation_id) -> sequence of
// dot dumps.
tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
*new absl::flat_hash_map<std::pair<int64, int64>,
std::vector<std::string>>();
// Generates a key to the fusion visualizer state mapping.
std::pair<int, int> FusionVisualizerStateKey(
const HloComputation& computation) {
return std::make_pair(computation.parent()->unique_id(),
computation.unique_id());
}
// Generates a fusion explorer page for the given computation: a standalone
// HTML document that lets the user step through the fusion snapshots
// previously recorded for it via RegisterFusionState, rendered as URLs by the
// registered plugin.
//
// Uses the data in fusion_visualizer_state and the URL renderer.
// Precondition: url_renderer != nullptr.
StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
    TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
  CHECK(url_renderer != nullptr);
  // NOTE(review): fusion_visualizer_state_mu is acquired while url_renderer_mu
  // is already held; any path that takes both locks must use this same order
  // (url_renderer_mu first) or risk deadlock — verify the other call sites.
  tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
  // All dot snapshots recorded so far for this (module, computation) pair, in
  // registration order. operator[] default-inserts an empty vector when
  // nothing was recorded, producing an explorer page with zero frames.
  const std::vector<std::string>& dot_graphs =
      fusion_visualizer_state[FusionVisualizerStateKey(computation)];
  // Render every snapshot to a URL via the plugin; fail fast on the first
  // rendering error.
  std::vector<std::string> dot_urls;
  dot_urls.reserve(dot_graphs.size());
  for (const std::string& dot : dot_graphs) {
    TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
    dot_urls.push_back(url);
  }
  // Splice the rendered URLs ($URLS, as a JS array literal) and a
  // module_computation title ($TITLE) into a static HTML/JS template. The page
  // shows one frame at a time in an iframe; Prev/Next links and j/k keys call
  // update() to cycle through the frames (modulo arithmetic wraps around).
  return absl::StrReplaceAll(
      R"(
<!doctype html>
<style>
html, body {height: 100%; text-align: center;}
#display {height: 80%; width: 80%;}
</style>
<title>Fusion Explorer: $TITLE</title>
<iframe id='display' width=80% height=80%></iframe>
<p id='description'></p>
<p>
<a id='prev' href='#'>Prev Step</a>
<a id='next' href='#'>Next Step</a>
</p>
<p>
Use j/k for keyboard navigation.
</p>
<script>
var currId = -1;
var urls = [$URLS];
var setIframe = function() {
document.getElementById('display').src = urls[currId];
};
var update = function(delta) {
currId = (currId + delta + urls.length) % urls.length;
document.getElementById('description').innerHTML = "Frame #"
+ (currId + 1) + " / " + urls.length;
setIframe();
};
document.getElementById('prev').onclick = function() {
update(-1);
return false;
};
document.getElementById('next').onclick = function() {
update(1);
return false;
};
window.addEventListener("keydown", function (event) {
if (event.defaultPrevented) {
return;
}
if (event.key == "j") {
update(1);
} else if (event.key == "k") {
update(-1);
} else {
return;
}
event.preventDefault();
}, true);
document.addEventListener("DOMContentLoaded", function() {
update(1);
});
</script>
)",
      // $URLS becomes a comma-separated list of double-quoted URL strings.
      {{"$URLS", absl::StrJoin(dot_urls, ", ",
                               [&](std::string* out, const std::string& url) {
                                 absl::StrAppend(out, "\"", url, "\"");
                               })},
       {"$TITLE",
        absl::StrCat(computation.parent()->name(), "_", computation.name())}});
}
// Precondition: (url_renderer != nullptr || (format != kUrl
// && format != kFusionVisualization)).
//
// (We specify this as a precondition rather than checking it in here and
// returning an error because we want to fail quickly when there's no URL
// renderer available, and this function runs only after we've done all the work
// of producing dot for the graph.)
StatusOr<string> WrapDotInFormat(absl::string_view dot,
StatusOr<string> WrapDotInFormat(const HloComputation& computation,
absl::string_view dot,
RenderedGraphFormat format)
TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
switch (format) {
@ -1595,6 +1697,8 @@ StatusOr<string> WrapDotInFormat(absl::string_view dot,
return WrapDotInHtml(dot);
case RenderedGraphFormat::kDot:
return string(dot);
case RenderedGraphFormat::kFusionVisualization:
return WrapFusionExplorer(computation);
}
}
@ -1613,6 +1717,25 @@ void RegisterGraphToURLRenderer(
std::move(renderer));
}
// Records a snapshot of `computation` for later fusion visualization:
// renders the computation to dot and appends it to the per-(module,
// computation) snapshot sequence consumed by the kFusionVisualization format.
// `label` names the pass/step that produced this snapshot and is embedded in
// the graph title.
Status RegisterFusionState(const HloComputation& computation,
                           absl::string_view label) {
  // Render to dot *before* taking fusion_visualizer_state_mu: RenderGraph
  // acquires url_renderer_mu (required by WrapDotInFormat), and
  // WrapFusionExplorer acquires fusion_visualizer_state_mu while already
  // holding url_renderer_mu. Taking the two locks here in the opposite order
  // (state mutex first, as the original code did) risks an ABBA deadlock.
  TF_ASSIGN_OR_RETURN(
      string dot_graph,
      RenderGraph(computation,
                  absl::StrCat(computation.parent()->name(), ", ",
                               computation.name(), ", ", label),
                  /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
                  /*hlo_execution_profile=*/nullptr,
                  /*hlo_render_options=*/{}));
  tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
  std::vector<std::string>& fusion_states =
      fusion_visualizer_state[FusionVisualizerStateKey(computation)];
  // Drop the snapshot when it is identical to the previous one: passes often
  // register without having changed the graph, and duplicate frames add
  // nothing to the visualization.
  if (fusion_states.empty() || fusion_states.back() != dot_graph) {
    fusion_states.push_back(std::move(dot_graph));
  }
  return Status::OK();
}
StatusOr<string> RenderGraph(const HloComputation& computation,
absl::string_view label,
const DebugOptions& debug_options,
@ -1628,7 +1751,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
HloDotDumper(&computation, label, debug_options, hlo_render_options,
hlo_execution_profile, NodeFilter())
.Dump();
return WrapDotInFormat(rendered_dot, format);
return WrapDotInFormat(computation, rendered_dot, format);
}
StatusOr<string> RenderNeighborhoodAround(
@ -1649,7 +1772,7 @@ StatusOr<string> RenderNeighborhoodAround(
hlo_render_options, /*profile=*/nullptr,
MakeNodeRadiusAroundFilter(&node, radius, boundary))
.Dump();
return WrapDotInFormat(rendered_dot, format);
return WrapDotInFormat(*node.parent(), rendered_dot, format);
}
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
@ -1680,7 +1803,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
/*profile=*/nullptr, filter)
.Dump();
return WrapDotInFormat(rendered_dot, format);
return WrapDotInFormat(*from.parent(), rendered_dot, format);
}
} // namespace xla

View File

@ -27,7 +27,7 @@ limitations under the License.
// human-readable graphical format.
//
// Fundamentally all graphs are rendered using the DOT language, but they can be
// packaged three different ways:
// packaged four different ways:
//
// - as a raw DOT file, which can be rendered using `graphviz`.
//
@ -36,7 +36,9 @@ limitations under the License.
//
// - as a URL hosted somewhere which somehow embeds the DOT file.
//
// This last option is not implemented by default, but you can add a plugin to
// - as an HTML page showing the fusion progress.
//
// The last two options are not implemented by default, but you can add a
// plugin to implement them via RegisterGraphToURLRenderer.
//
// TODO(jlebar): Rename this file to hlo_graph_renderer.
@ -48,6 +50,7 @@ enum class RenderedGraphFormat {
kDot,
kHtml,
kUrl,
kFusionVisualization,
};
struct HloRenderOptions {
@ -92,6 +95,11 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
RenderedGraphFormat format,
HloRenderOptions hlo_render_options = {});
// Registers the fusion state of the graph for future visualization using
// the kFusionVisualization render format.
Status RegisterFusionState(const HloComputation& computation,
absl::string_view label);
// Registers a function which implements RenderedGraphFormat::kUrl.
//
// The input to the function is dot, and the output should be a URL or an error.

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@ -577,6 +578,12 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
}
break;
}
if (module->config().debug_options().xla_dump_fusion_visualization()) {
TF_RETURN_IF_ERROR(RegisterFusionState(
*computation,
absl::StrCat("InstructionFusion, may_duplicate=", may_duplicate_)));
}
}
if (config_collection_mode_ != FusionConfigCollection::kOff) {

View File

@ -250,6 +250,9 @@ message DebugOptions {
// Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML)
bool xla_dump_hlo_as_html = 116;
// Dump the visualization of the fusion progress.
bool xla_dump_fusion_visualization = 149;
// If true, every time an HLO module is run, we will dump an HloSnapshot
// (essentially, a serialized module plus its inputs) to the --xla_dump_to
// directory.
@ -311,7 +314,7 @@ message DebugOptions {
// Compilation errors out if these ops are encountered.
bool xla_gpu_deterministic_ops = 148;
// Next id: 149
// Next id: 150
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.