diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 20040265f44..5486002606d 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -40,6 +40,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_asm_extra_flags(""); opts.set_xla_eliminate_hlo_implicit_broadcast(true); opts.set_xla_dump_hlo_as_html(false); + opts.set_xla_dump_fusion_visualization(false); opts.set_xla_dump_include_timestamp(true); opts.set_xla_dump_max_hlo_modules(-1); opts.set_xla_dump_module_metadata(false); @@ -483,6 +484,15 @@ static void AllocateFlags() { "directory specified by --xla_dump_to). This is not implemented by " "default; you need to add a plugin which calls " "RegisterGraphToURLRenderer().")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_fusion_visualization", + bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization), + flag_values->xla_dump_fusion_visualization(), + "Tries to generate HLO fusion visualization as an HTML page to the " + "directory specified by --xla_dump_to). This is not implemented by " + "default; you need to add a plugin which calls " + "RegisterGraphToURLRenderer(). Generates a file per computation. 
" + "Currently only implemented for the GPU backend.")); flag_objects->push_back(tensorflow::Flag( "xla_dump_hlo_snapshots", bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b595f33640c..5ab184644a2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1655,6 +1655,7 @@ cc_library( deps = [ ":fusion_queue", ":hlo", + ":hlo_graph_dumper", ":hlo_pass", ":hlo_reachability", ":pattern_matcher", diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 9c2aa7fe4d0..cf1ad28d8f7 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -43,6 +43,7 @@ struct CanonicalDebugOptions { dump_as_dot(opts.xla_dump_hlo_as_dot()), dump_as_html(opts.xla_dump_hlo_as_html()), dump_as_url(opts.xla_dump_hlo_as_url()), + dump_fusion_visualization(opts.xla_dump_fusion_visualization()), dump_snapshots(opts.xla_dump_hlo_snapshots()), dump_include_timestamp(opts.xla_dump_include_timestamp()), dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()), @@ -135,6 +136,7 @@ struct CanonicalDebugOptions { bool dump_as_dot; bool dump_as_html; bool dump_as_url; + bool dump_fusion_visualization; bool dump_snapshots; bool dump_include_timestamp; int64 dump_max_hlo_modules; @@ -268,6 +270,24 @@ std::vector DumpHloModuleImpl(const HloModule& module, render_graph(RenderedGraphFormat::kHtml), opts)); } + if (opts.dump_fusion_visualization) { + for (const HloComputation* computation : + module.MakeNonfusionComputations()) { + StatusOr rendered_graph = RenderGraph( + *computation, + /*label=*/absl::StrCat(filename, "_", computation->name()), + module.config().debug_options(), + RenderedGraphFormat::kFusionVisualization, profile); + file_paths.push_back(DumpToFileInDirImpl( + StrFormat("%s_%s_fusion_visualization.html", filename, + computation->name()), + rendered_graph.ok() 
? *rendered_graph + : StrFormat("Error rendering graph: %s", + rendered_graph.status().ToString()), + opts)); + } + } + // Special case for rendering graphs as URLs. We'll dump them to a file // because why not, but we always log them to stdout as well. if (opts.dump_as_url) { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d7c1831594d..8742aac5e2a 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1113,6 +1113,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", @@ -1209,6 +1210,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index e1f4977e73c..0f9c6397ffe 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -314,7 +315,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " }"; // Remove 'fusion' instruction. CHECK_EQ(0, fusion->user_count()) << fusion->ToString(); - return computation_->RemoveInstruction(fusion); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(fusion)); + if (computation_->parent() + ->config() + .debug_options() + .xla_dump_fusion_visualization()) { + TF_RETURN_IF_ERROR(RegisterFusionState(*computation_, "fusion merger")); + } + + return Status::OK(); } StatusOr FusionMerger::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index ad087d4b262..7ac2d982a26 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -242,12 +243,23 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) { return changed; } -bool GpuMultiOutputFusion::DoMultiOutputFusion() { +StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { bool changed = false; RecomputeReachability(); std::vector defs_before_uses = computation_->MakeInstructionPostOrder(); + auto dump_fusion_state = [&] { + if (computation_->parent() + ->config() + .debug_options() + .xla_dump_fusion_visualization()) { + TF_RETURN_IF_ERROR( + RegisterFusionState(*computation_, "GpuMultiOutputFusion")); + } + return Status::OK(); + }; + while (!defs_before_uses.empty()) { // Traverse the HLO in uses-before-defs order by removing instruction from // the back of the vector. 
@@ -290,6 +302,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() { CHECK_EQ(0, producer->user_count()); TF_CHECK_OK(computation_->RemoveInstruction(producer)); } + + TF_RETURN_IF_ERROR(dump_fusion_state()); RecomputeReachability(); continue; } @@ -309,6 +323,8 @@ bool GpuMultiOutputFusion::DoMultiOutputFusion() { CHECK_EQ(0, producer->user_count()); TF_CHECK_OK(computation_->RemoveInstruction(producer)); } + + TF_RETURN_IF_ERROR(dump_fusion_state()); RecomputeReachability(); } return changed; @@ -318,7 +334,8 @@ StatusOr GpuMultiOutputFusion::Run(HloModule* module) { bool changed = false; for (auto* computation : module->MakeNonfusionComputations()) { computation_ = computation; - if (DoMultiOutputFusion()) { + TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion()); + if (fusion_changed) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index c715d31fe48..78cda6c4ed9 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -98,7 +98,7 @@ class GpuMultiOutputFusion : public HloModulePass { private: bool FuseSiblings(HloInstruction* parent); - bool DoMultiOutputFusion(); + StatusOr DoMultiOutputFusion(); // Recompute reachability for the current computation. void RecomputeReachability(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 533e6b11160..f71cf059c1f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1577,13 +1577,115 @@ tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED); std::function(absl::string_view)>* url_renderer TF_GUARDED_BY(url_renderer_mu) = nullptr; -// Precondition: url_renderer != nullptr. +// Storage for fusion visualization: (module_id, computation_id) -> sequence of +// dot dumps. 
+tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED); +static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) = + *new absl::flat_hash_map, + std::vector>(); + +// Generates a key to the fusion visualizer state mapping. +std::pair FusionVisualizerStateKey( + const HloComputation& computation) { + return std::make_pair(computation.parent()->unique_id(), + computation.unique_id()); +} + +// Generates a fusion explorer for the given computation using the data in +// fusion_visualizer_state and the URL renderer. Precondition: url_renderer != +// nullptr. +StatusOr WrapFusionExplorer(const HloComputation& computation) + TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { + CHECK(url_renderer != nullptr); + tensorflow::mutex_lock lock(fusion_visualizer_state_mu); + const std::vector& dot_graphs = + fusion_visualizer_state[FusionVisualizerStateKey(computation)]; + std::vector dot_urls; + dot_urls.reserve(dot_graphs.size()); + for (const std::string& dot : dot_graphs) { + TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot)); + dot_urls.push_back(url); + } + + return absl::StrReplaceAll( + R"( + + + Fusion Explorer: $TITLE + +

+

+ + +

+

+ Use j/k for keyboard navigation. +

+ + )", + {{"$URLS", absl::StrJoin(dot_urls, ", ", + [&](std::string* out, const std::string& url) { + absl::StrAppend(out, "\"", url, "\""); + })}, + {"$TITLE", + absl::StrCat(computation.parent()->name(), "_", computation.name())}}); +} + +// Precondition: (url_renderer != nullptr || (format != kUrl +// && format != kFusionVisualization)). // // (We specify this as a precondition rather than checking it in here and // returning an error because we want to fail quickly when there's no URL // renderer available, and this function runs only after we've done all the work // of producing dot for the graph.) -StatusOr WrapDotInFormat(absl::string_view dot, +StatusOr WrapDotInFormat(const HloComputation& computation, + absl::string_view dot, RenderedGraphFormat format) TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { switch (format) { @@ -1595,6 +1697,8 @@ StatusOr WrapDotInFormat(absl::string_view dot, return WrapDotInHtml(dot); case RenderedGraphFormat::kDot: return string(dot); + case RenderedGraphFormat::kFusionVisualization: + return WrapFusionExplorer(computation); } } @@ -1613,6 +1717,25 @@ void RegisterGraphToURLRenderer( std::move(renderer)); } +Status RegisterFusionState(const HloComputation& computation, + absl::string_view label) { + tensorflow::mutex_lock lock(fusion_visualizer_state_mu); + TF_ASSIGN_OR_RETURN( + string dot_graph, + RenderGraph(computation, + absl::StrCat(computation.parent()->name(), ", ", + computation.name(), ", ", label), + /*debug_options=*/{}, xla::RenderedGraphFormat::kDot, + /*hlo_execution_profile=*/nullptr, + /*hlo_render_options=*/{})); + std::vector& fusion_states = + fusion_visualizer_state[FusionVisualizerStateKey(computation)]; + if (fusion_states.empty() || fusion_states.back() != dot_graph) { + fusion_states.push_back(dot_graph); + } + return Status::OK(); +} + StatusOr RenderGraph(const HloComputation& computation, absl::string_view label, const DebugOptions& debug_options, @@ -1628,7 +1751,7 @@ StatusOr RenderGraph(const 
HloComputation& computation, HloDotDumper(&computation, label, debug_options, hlo_render_options, hlo_execution_profile, NodeFilter()) .Dump(); - return WrapDotInFormat(rendered_dot, format); + return WrapDotInFormat(computation, rendered_dot, format); } StatusOr RenderNeighborhoodAround( @@ -1649,7 +1772,7 @@ StatusOr RenderNeighborhoodAround( hlo_render_options, /*profile=*/nullptr, MakeNodeRadiusAroundFilter(&node, radius, boundary)) .Dump(); - return WrapDotInFormat(rendered_dot, format); + return WrapDotInFormat(*node.parent(), rendered_dot, format); } StatusOr RenderAllPathsFromTo(const HloInstruction& from, @@ -1680,7 +1803,7 @@ StatusOr RenderAllPathsFromTo(const HloInstruction& from, HloDotDumper(from.parent(), label, debug_options, hlo_render_options, /*profile=*/nullptr, filter) .Dump(); - return WrapDotInFormat(rendered_dot, format); + return WrapDotInFormat(*from.parent(), rendered_dot, format); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 528de77e4e6..21907c6c5da 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -27,7 +27,7 @@ limitations under the License. // human-readable graphical format. // // Fundamentally all graphs are rendered using the DOT language, but they can be -// packaged three different ways: +// packaged four different ways: // // - as a raw DOT file, which can be rendered using `graphviz`. // @@ -36,7 +36,9 @@ limitations under the License. // // - as a URL hosted somewhere which somehow embeds the DOT file. // -// This last option is not implemented by default, but you can add a plugin to +// - as an HTML page showing the fusion progress. +// +// The last two options are not implemented by default, but you can add a plugin to // implement it via RegisterGraphToURLRenderer. // // TODO(jlebar): Rename this file to hlo_graph_renderer. 
@@ -48,6 +50,7 @@ enum class RenderedGraphFormat { kDot, kHtml, kUrl, + kFusionVisualization, }; struct HloRenderOptions { @@ -92,6 +95,11 @@ StatusOr RenderAllPathsFromTo(const HloInstruction& from, RenderedGraphFormat format, HloRenderOptions hlo_render_options = {}); +// Registers the fusion state of the graph for future visualization using +// the kFusionVisualization render format. +Status RegisterFusionState(const HloComputation& computation, + absl::string_view label); + // Registers a function which implements RenderedGraphFormat::kUrl. // // The input to the function is dot, and the output should be a URL or an error. diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index faa7e8e107e..1c27f81635f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -577,6 +578,12 @@ StatusOr InstructionFusion::Run(HloModule* module) { } break; } + + if (module->config().debug_options().xla_dump_fusion_visualization()) { + TF_RETURN_IF_ERROR(RegisterFusionState( + *computation, + absl::StrCat("InstructionFusion, may_duplicate=", may_duplicate_))); + } } if (config_collection_mode_ != FusionConfigCollection::kOff) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 2d86b6efac8..f6f5a72be5c 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -250,6 +250,9 @@ message DebugOptions { // Dump HLO graphs as an HTML 
(DOT -> SVG inlined in HTML) bool xla_dump_hlo_as_html = 116; + // Dump the visualization of the fusion progress. + bool xla_dump_fusion_visualization = 149; + // If true, every time an HLO module is run, we will dump an HloSnapshot // (essentially, a serialized module plus its inputs) to the --xla_dump_to // directory. @@ -311,7 +314,7 @@ message DebugOptions { // Compilation errors out if these ops are encountered. bool xla_gpu_deterministic_ops = 148; - // Next id: 149 + // Next id: 150 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.