From b2ad941e21a347710014893423c07d34e05f0f16 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 10 Dec 2020 23:16:11 -0800 Subject: [PATCH] Print HLO module name on slow compile PiperOrigin-RevId: 346940007 Change-Id: Ifde06eed8af2ca1be52366ef32c1a916d330b523 --- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 4 +++- tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 4 +++- .../compiler/xla/service/slow_operation_alarm.cc | 13 ++++++++++--- .../compiler/xla/service/slow_operation_alarm.h | 4 +++- 5 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 482b5f37579..36982700d5e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -5170,6 +5170,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ee1f573c867..5bd2d13688b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -637,7 +637,9 @@ StatusOr> CpuCompiler::RunBackend( VLOG(1) << "Compiling: " << module->name(); XLA_SCOPED_LOGGING_TIMER( absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); - auto slow_compile_alarm = SlowCompilationAlarm(); + std::string slow_compilation_msg = + absl::StrCat("Compiling module ", module->name()); + auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg); TF_RET_CHECK(stream_exec != nullptr); absl::call_once(llvm_command_line_options_initialized, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 
4073b4f458b..6f36e02ad2c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -783,7 +783,9 @@ StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); - auto slow_compile_alarm = SlowCompilationAlarm(); + std::string slow_compilation_msg = + absl::StrCat("Compiling module ", module->name()); + auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg); TF_RET_CHECK(stream_exec != nullptr); diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.cc b/tensorflow/compiler/xla/service/slow_operation_alarm.cc index 2ce66b25daa..13f6ac34a52 100644 --- a/tensorflow/compiler/xla/service/slow_operation_alarm.cc +++ b/tensorflow/compiler/xla/service/slow_operation_alarm.cc @@ -106,12 +106,19 @@ SlowOperationAlarm::SlowOperationAlarm(absl::Duration timeout, string msg, SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); } -std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() { +std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm( + absl::string_view msg) { // Pass a counter to these alarms so they only log once every power-of-two // occurrences. static auto* counter = new std::atomic(0); const char* separator = "\n********************************"; + + std::string msg_suffix; + if (!msg.empty()) { + msg_suffix = absl::StrCat("\n", msg); + } + #if NDEBUG return absl::make_unique<SlowOperationAlarm>( absl::Duration(absl::Minutes(2)), @@ -119,7 +126,7 @@ std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() { separator, "\nVery slow compile? If you want to file a bug, run with envvar " "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.", - separator), + msg_suffix, separator), counter); #else return absl::make_unique<SlowOperationAlarm>( @@ -128,7 +135,7 @@ std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() { separator, "\nSlow compile? XLA was built without compiler optimizations, " "which can be slow. 
Try rebuilding with -c opt.", - separator), + msg_suffix, separator), counter); #endif } diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.h b/tensorflow/compiler/xla/service/slow_operation_alarm.h index 20099bb875f..bd845912e72 100644 --- a/tensorflow/compiler/xla/service/slow_operation_alarm.h +++ b/tensorflow/compiler/xla/service/slow_operation_alarm.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +65,8 @@ class SlowOperationAlarm { // In opt builds, recommends filing a bug. // // This is throttled to once-every-power-of-two occurrences, globally. -ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(); +ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm( + absl::string_view msg = ""); } // namespace xla