Print HLO module name on slow compile

PiperOrigin-RevId: 346940007
Change-Id: Ifde06eed8af2ca1be52366ef32c1a916d330b523
This commit is contained in:
Sanjoy Das 2020-12-10 23:16:11 -08:00 committed by TensorFlower Gardener
parent 20423e72df
commit b2ad941e21
5 changed files with 20 additions and 6 deletions

View File

@ -5170,6 +5170,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],

View File

@ -637,7 +637,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
VLOG(1) << "Compiling: " << module->name();
XLA_SCOPED_LOGGING_TIMER(
absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
auto slow_compile_alarm = SlowCompilationAlarm();
std::string slow_compilation_msg =
absl::StrCat("Compiling module ", module->name());
auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
TF_RET_CHECK(stream_exec != nullptr);
absl::call_once(llvm_command_line_options_initialized,

View File

@ -783,7 +783,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
const CompileOptions& options) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
auto slow_compile_alarm = SlowCompilationAlarm();
std::string slow_compilation_msg =
absl::StrCat("Compiling module ", module->name());
auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
TF_RET_CHECK(stream_exec != nullptr);

View File

@ -106,12 +106,19 @@ SlowOperationAlarm::SlowOperationAlarm(absl::Duration timeout, string msg,
SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); }
std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() {
std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
absl::string_view msg) {
// Pass a counter to these alarms so they only log once every power-of-two
// occurrences.
static auto* counter = new std::atomic<int64>(0);
const char* separator = "\n********************************";
std::string msg_suffix;
if (!msg.empty()) {
msg_suffix = absl::StrCat("\n", msg);
}
#if NDEBUG
return absl::make_unique<SlowOperationAlarm>(
absl::Duration(absl::Minutes(2)),
@ -119,7 +126,7 @@ std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() {
separator,
"\nVery slow compile? If you want to file a bug, run with envvar "
"XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.",
separator),
msg_suffix, separator),
counter);
#else
return absl::make_unique<SlowOperationAlarm>(
@ -128,7 +135,7 @@ std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm() {
separator,
"\nSlow compile? XLA was built without compiler optimizations, "
"which can be slow. Try rebuilding with -c opt.",
separator),
msg_suffix, separator),
counter);
#endif
}

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/types.h"
@ -64,7 +65,8 @@ class SlowOperationAlarm {
// In opt builds, recommends filing a bug.
//
// This is throttled to once-every-power-of-two occurrences, globally.
ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm();
ABSL_MUST_USE_RESULT std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
absl::string_view msg = "");
} // namespace xla