[XLA] Implement the fusion progress visualizer, which dumps out the HTML+JS page visualizing the fusion decisions for XLA:GPU
Requires the graphviz URL renderer to be registered. PiperOrigin-RevId: 353765943 Change-Id: I1994c2f8abaf15b3da6e914b351c26993f5765f5
This commit is contained in:
parent
588f5a7f60
commit
d92917b6d6
@ -40,6 +40,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
||||
opts.set_xla_gpu_asm_extra_flags("");
|
||||
opts.set_xla_eliminate_hlo_implicit_broadcast(true);
|
||||
opts.set_xla_dump_hlo_as_html(false);
|
||||
opts.set_xla_dump_fusion_visualization(false);
|
||||
opts.set_xla_dump_include_timestamp(true);
|
||||
opts.set_xla_dump_max_hlo_modules(-1);
|
||||
opts.set_xla_dump_module_metadata(false);
|
||||
@ -483,6 +484,15 @@ static void AllocateFlags() {
|
||||
"directory specified by --xla_dump_to). This is not implemented by "
|
||||
"default; you need to add a plugin which calls "
|
||||
"RegisterGraphToURLRenderer()."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_dump_fusion_visualization",
|
||||
bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization),
|
||||
flag_values->xla_dump_fusion_visualization(),
|
||||
"Tries to generate HLO fusion visualization as an HTML page to the "
|
||||
"directory specified by --xla_dump_to). This is not implemented by "
|
||||
"default; you need to add a plugin which calls "
|
||||
"RegisterGraphToURLRenderer(). Generates a file per computation. "
|
||||
"Currently only implemented for the GPU backend."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_dump_hlo_snapshots",
|
||||
bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
|
||||
|
@ -1655,6 +1655,7 @@ cc_library(
|
||||
deps = [
|
||||
":fusion_queue",
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
":hlo_pass",
|
||||
":hlo_reachability",
|
||||
":pattern_matcher",
|
||||
|
@ -43,6 +43,7 @@ struct CanonicalDebugOptions {
|
||||
dump_as_dot(opts.xla_dump_hlo_as_dot()),
|
||||
dump_as_html(opts.xla_dump_hlo_as_html()),
|
||||
dump_as_url(opts.xla_dump_hlo_as_url()),
|
||||
dump_fusion_visualization(opts.xla_dump_fusion_visualization()),
|
||||
dump_snapshots(opts.xla_dump_hlo_snapshots()),
|
||||
dump_include_timestamp(opts.xla_dump_include_timestamp()),
|
||||
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()),
|
||||
@ -135,6 +136,7 @@ struct CanonicalDebugOptions {
|
||||
bool dump_as_dot;
|
||||
bool dump_as_html;
|
||||
bool dump_as_url;
|
||||
bool dump_fusion_visualization;
|
||||
bool dump_snapshots;
|
||||
bool dump_include_timestamp;
|
||||
int64 dump_max_hlo_modules;
|
||||
@ -268,6 +270,24 @@ std::vector<std::string> DumpHloModuleImpl(const HloModule& module,
|
||||
render_graph(RenderedGraphFormat::kHtml), opts));
|
||||
}
|
||||
|
||||
if (opts.dump_fusion_visualization) {
|
||||
for (const HloComputation* computation :
|
||||
module.MakeNonfusionComputations()) {
|
||||
StatusOr<string> rendered_graph = RenderGraph(
|
||||
*computation,
|
||||
/*label=*/absl::StrCat(filename, "_", computation->name()),
|
||||
module.config().debug_options(),
|
||||
RenderedGraphFormat::kFusionVisualization, profile);
|
||||
file_paths.push_back(DumpToFileInDirImpl(
|
||||
StrFormat("%s_%s_fusion_visualization.html", filename,
|
||||
computation->name()),
|
||||
rendered_graph.ok() ? *rendered_graph
|
||||
: StrFormat("Error rendering graph: %s",
|
||||
rendered_graph.status().ToString()),
|
||||
opts));
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for rendering graphs as URLs. We'll dump them to a file
|
||||
// because why not, but we always log them to stdout as well.
|
||||
if (opts.dump_as_url) {
|
||||
|
@ -1113,6 +1113,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_reachability",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
@ -1209,6 +1210,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -314,7 +315,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
<< " }";
|
||||
// Remove 'fusion' instruction.
|
||||
CHECK_EQ(0, fusion->user_count()) << fusion->ToString();
|
||||
return computation_->RemoveInstruction(fusion);
|
||||
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(fusion));
|
||||
if (computation_->parent()
|
||||
->config()
|
||||
.debug_options()
|
||||
.xla_dump_fusion_visualization()) {
|
||||
TF_RETURN_IF_ERROR(RegisterFusionState(*computation_, "fusion merger"));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> FusionMerger::Run(HloModule* module) {
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||
@ -242,12 +243,23 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GpuMultiOutputFusion::DoMultiOutputFusion() {
|
||||
StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
|
||||
bool changed = false;
|
||||
RecomputeReachability();
|
||||
std::vector<HloInstruction*> defs_before_uses =
|
||||
computation_->MakeInstructionPostOrder();
|
||||
|
||||
auto dump_fusion_state = [&] {
|
||||
if (computation_->parent()
|
||||
->config()
|
||||
.debug_options()
|
||||
.xla_dump_fusion_visualization()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RegisterFusionState(*computation_, "GpuMultiOutputFusion"));
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
while (!defs_before_uses.empty()) {
|
||||
// Traverse the HLO in uses-before-defs order by removing instruction from
|
||||
// the back of the vector.
|
||||
@ -290,6 +302,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() {
|
||||
CHECK_EQ(0, producer->user_count());
|
||||
TF_CHECK_OK(computation_->RemoveInstruction(producer));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(dump_fusion_state());
|
||||
RecomputeReachability();
|
||||
continue;
|
||||
}
|
||||
@ -309,6 +323,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() {
|
||||
CHECK_EQ(0, producer->user_count());
|
||||
TF_CHECK_OK(computation_->RemoveInstruction(producer));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(dump_fusion_state());
|
||||
RecomputeReachability();
|
||||
}
|
||||
return changed;
|
||||
@ -318,7 +334,8 @@ StatusOr<bool> GpuMultiOutputFusion::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
for (auto* computation : module->MakeNonfusionComputations()) {
|
||||
computation_ = computation;
|
||||
if (DoMultiOutputFusion()) {
|
||||
TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
|
||||
if (fusion_changed) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -98,7 +98,7 @@ class GpuMultiOutputFusion : public HloModulePass {
|
||||
private:
|
||||
bool FuseSiblings(HloInstruction* parent);
|
||||
|
||||
bool DoMultiOutputFusion();
|
||||
StatusOr<bool> DoMultiOutputFusion();
|
||||
|
||||
// Recompute reachability for the current computation.
|
||||
void RecomputeReachability();
|
||||
|
@ -1577,13 +1577,115 @@ tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
|
||||
std::function<StatusOr<string>(absl::string_view)>* url_renderer
|
||||
TF_GUARDED_BY(url_renderer_mu) = nullptr;
|
||||
|
||||
// Precondition: url_renderer != nullptr.
|
||||
// Storage for fusion visualization: (module_id, computation_id) -> sequence of
|
||||
// dot dumps.
|
||||
tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
|
||||
static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
|
||||
*new absl::flat_hash_map<std::pair<int64, int64>,
|
||||
std::vector<std::string>>();
|
||||
|
||||
// Generates a key to the fusion visualizer state mapping.
|
||||
std::pair<int, int> FusionVisualizerStateKey(
|
||||
const HloComputation& computation) {
|
||||
return std::make_pair(computation.parent()->unique_id(),
|
||||
computation.unique_id());
|
||||
}
|
||||
|
||||
// Generates a fusion explorer for the given computation using the data in
|
||||
// fusion_visualizer_state and the URL renderer. Precondition: url_renderer !=
|
||||
// nullptr.
|
||||
StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
|
||||
CHECK(url_renderer != nullptr);
|
||||
tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
|
||||
const std::vector<std::string>& dot_graphs =
|
||||
fusion_visualizer_state[FusionVisualizerStateKey(computation)];
|
||||
std::vector<std::string> dot_urls;
|
||||
dot_urls.reserve(dot_graphs.size());
|
||||
for (const std::string& dot : dot_graphs) {
|
||||
TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
|
||||
dot_urls.push_back(url);
|
||||
}
|
||||
|
||||
return absl::StrReplaceAll(
|
||||
R"(
|
||||
<!doctype html>
|
||||
<style>
|
||||
html, body {height: 100%; text-align: center;}
|
||||
#display {height: 80%; width: 80%;}
|
||||
</style>
|
||||
<title>Fusion Explorer: $TITLE</title>
|
||||
<iframe id='display' width=80% height=80%></iframe>
|
||||
<p id='description'></p>
|
||||
<p>
|
||||
<a id='prev' href='#'>Prev Step</a>
|
||||
<a id='next' href='#'>Next Step</a>
|
||||
</p>
|
||||
<p>
|
||||
Use j/k for keyboard navigation.
|
||||
</p>
|
||||
<script>
|
||||
var currId = -1;
|
||||
var urls = [$URLS];
|
||||
|
||||
var setIframe = function() {
|
||||
document.getElementById('display').src = urls[currId];
|
||||
};
|
||||
|
||||
var update = function(delta) {
|
||||
currId = (currId + delta + urls.length) % urls.length;
|
||||
document.getElementById('description').innerHTML = "Frame #"
|
||||
+ (currId + 1) + " / " + urls.length;
|
||||
setIframe();
|
||||
};
|
||||
|
||||
document.getElementById('prev').onclick = function() {
|
||||
update(-1);
|
||||
return false;
|
||||
};
|
||||
|
||||
document.getElementById('next').onclick = function() {
|
||||
update(1);
|
||||
return false;
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", function (event) {
|
||||
if (event.defaultPrevented) {
|
||||
return;
|
||||
}
|
||||
if (event.key == "j") {
|
||||
update(1);
|
||||
} else if (event.key == "k") {
|
||||
update(-1);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
event.preventDefault();
|
||||
}, true);
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
update(1);
|
||||
});
|
||||
|
||||
</script>
|
||||
)",
|
||||
{{"$URLS", absl::StrJoin(dot_urls, ", ",
|
||||
[&](std::string* out, const std::string& url) {
|
||||
absl::StrAppend(out, "\"", url, "\"");
|
||||
})},
|
||||
{"$TITLE",
|
||||
absl::StrCat(computation.parent()->name(), "_", computation.name())}});
|
||||
}
|
||||
|
||||
// Precondition: (url_renderer != nullptr || (format != kUrl
|
||||
// && format != kFusionVisualization)).
|
||||
//
|
||||
// (We specify this as a precondition rather than checking it in here and
|
||||
// returning an error because we want to fail quickly when there's no URL
|
||||
// renderer available, and this function runs only after we've done all the work
|
||||
// of producing dot for the graph.)
|
||||
StatusOr<string> WrapDotInFormat(absl::string_view dot,
|
||||
StatusOr<string> WrapDotInFormat(const HloComputation& computation,
|
||||
absl::string_view dot,
|
||||
RenderedGraphFormat format)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
|
||||
switch (format) {
|
||||
@ -1595,6 +1697,8 @@ StatusOr<string> WrapDotInFormat(absl::string_view dot,
|
||||
return WrapDotInHtml(dot);
|
||||
case RenderedGraphFormat::kDot:
|
||||
return string(dot);
|
||||
case RenderedGraphFormat::kFusionVisualization:
|
||||
return WrapFusionExplorer(computation);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1613,6 +1717,25 @@ void RegisterGraphToURLRenderer(
|
||||
std::move(renderer));
|
||||
}
|
||||
|
||||
// Records the current state of `computation` as the next step of its fusion
// visualization. The computation is rendered to dot, labelled
// "<module name>, <computation name>, <label>", and appended to the sequence
// stored for it in fusion_visualizer_state. A rendering identical to the most
// recently stored one is dropped, so a pass that changed nothing adds no frame.
//
// Returns the error from RenderGraph if rendering fails, in which case nothing
// is stored; otherwise returns OK.
Status RegisterFusionState(const HloComputation& computation,
                           absl::string_view label) {
  // Render before taking fusion_visualizer_state_mu. RenderGraph ends in
  // WrapDotInFormat, which requires url_renderer_mu, while WrapFusionExplorer
  // takes fusion_visualizer_state_mu with url_renderer_mu already held.
  // Rendering under fusion_visualizer_state_mu would therefore acquire the two
  // mutexes in the opposite order.
  TF_ASSIGN_OR_RETURN(
      string dot_graph,
      RenderGraph(computation,
                  absl::StrCat(computation.parent()->name(), ", ",
                               computation.name(), ", ", label),
                  /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
                  /*hlo_execution_profile=*/nullptr,
                  /*hlo_render_options=*/{}));

  tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
  std::vector<std::string>& fusion_states =
      fusion_visualizer_state[FusionVisualizerStateKey(computation)];
  if (fusion_states.empty() || fusion_states.back() != dot_graph) {
    fusion_states.push_back(std::move(dot_graph));
  }
  return Status::OK();
}
|
||||
|
||||
StatusOr<string> RenderGraph(const HloComputation& computation,
|
||||
absl::string_view label,
|
||||
const DebugOptions& debug_options,
|
||||
@ -1628,7 +1751,7 @@ StatusOr<string> RenderGraph(const HloComputation& computation,
|
||||
HloDotDumper(&computation, label, debug_options, hlo_render_options,
|
||||
hlo_execution_profile, NodeFilter())
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
return WrapDotInFormat(computation, rendered_dot, format);
|
||||
}
|
||||
|
||||
StatusOr<string> RenderNeighborhoodAround(
|
||||
@ -1649,7 +1772,7 @@ StatusOr<string> RenderNeighborhoodAround(
|
||||
hlo_render_options, /*profile=*/nullptr,
|
||||
MakeNodeRadiusAroundFilter(&node, radius, boundary))
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
return WrapDotInFormat(*node.parent(), rendered_dot, format);
|
||||
}
|
||||
|
||||
StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
@ -1680,7 +1803,7 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
|
||||
/*profile=*/nullptr, filter)
|
||||
.Dump();
|
||||
return WrapDotInFormat(rendered_dot, format);
|
||||
return WrapDotInFormat(*from.parent(), rendered_dot, format);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
// human-readable graphical format.
|
||||
//
|
||||
// Fundamentally all graphs are rendered using the DOT language, but they can be
|
||||
// packaged three different ways:
|
||||
// packaged four different ways:
|
||||
//
|
||||
// - as a raw DOT file, which can be rendered using `graphviz`.
|
||||
//
|
||||
@ -36,7 +36,9 @@ limitations under the License.
|
||||
//
|
||||
// - as a URL hosted somewhere which somehow embeds the DOT file.
|
||||
//
|
||||
// This last option is not implemented by default, but you can add a plugin to
|
||||
// - as an HTML page showing the fusion progress.
|
||||
//
|
||||
// The last two options are not implemented by default, but you can add a plugin
|
||||
// to implement them via RegisterGraphToURLRenderer.
|
||||
//
|
||||
// TODO(jlebar): Rename this file to hlo_graph_renderer.
|
||||
@ -48,6 +50,7 @@ enum class RenderedGraphFormat {
|
||||
kDot,
|
||||
kHtml,
|
||||
kUrl,
|
||||
kFusionVisualization,
|
||||
};
|
||||
|
||||
struct HloRenderOptions {
|
||||
@ -92,6 +95,11 @@ StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
|
||||
RenderedGraphFormat format,
|
||||
HloRenderOptions hlo_render_options = {});
|
||||
|
||||
// Registers the fusion state of the graph for future visualization using
|
||||
// the kFusionVisualization render format.
|
||||
Status RegisterFusionState(const HloComputation& computation,
|
||||
absl::string_view label);
|
||||
|
||||
// Registers a function which implements RenderedGraphFormat::kUrl.
|
||||
//
|
||||
// The input to the function is dot, and the output should be a URL or an error.
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/fusion_queue.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||
@ -577,6 +578,12 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (module->config().debug_options().xla_dump_fusion_visualization()) {
|
||||
TF_RETURN_IF_ERROR(RegisterFusionState(
|
||||
*computation,
|
||||
absl::StrCat("InstructionFusion, may_duplicate=", may_duplicate_)));
|
||||
}
|
||||
}
|
||||
|
||||
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
||||
|
@ -250,6 +250,9 @@ message DebugOptions {
|
||||
// Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML)
|
||||
bool xla_dump_hlo_as_html = 116;
|
||||
|
||||
// Dump the visualization of the fusion progress.
|
||||
bool xla_dump_fusion_visualization = 149;
|
||||
|
||||
// If true, every time an HLO module is run, we will dump an HloSnapshot
|
||||
// (essentially, a serialized module plus its inputs) to the --xla_dump_to
|
||||
// directory.
|
||||
@ -311,7 +314,7 @@ message DebugOptions {
|
||||
// Compilation errors out if these ops are encountered.
|
||||
bool xla_gpu_deterministic_ops = 148;
|
||||
|
||||
// Next id: 149
|
||||
// Next id: 150
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
Loading…
Reference in New Issue
Block a user