Initialize the TF logger asynchronously
This prevents a slow-to-construct logger from affecting latency. PiperOrigin-RevId: 256190005
This commit is contained in:
parent
32499e2a2b
commit
3c48e25f06
@ -29,9 +29,15 @@ class XlaActivityLoggingListener final : public XlaActivityListener {
|
||||
VLOG(3) << "Logging XlaAutoClusteringActivity disabled";
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(2) << "Logging XlaAutoClusteringActivity";
|
||||
VLOG(3) << auto_clustering_activity.DebugString();
|
||||
Logger::Singleton()->LogProto(auto_clustering_activity);
|
||||
|
||||
if (Logger* logger = Logger::GetSingletonAsync()) {
|
||||
VLOG(2) << "Logging XlaAutoClusteringActivity";
|
||||
VLOG(3) << auto_clustering_activity.DebugString();
|
||||
logger->LogProto(auto_clustering_activity);
|
||||
} else {
|
||||
VLOG(2) << "Not logging: logger not ready yet.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -41,9 +47,15 @@ class XlaActivityLoggingListener final : public XlaActivityListener {
|
||||
VLOG(3) << "Logging XlaJitCompilationActivity disabled";
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(2) << "Logging XlaJitCompilationActivity";
|
||||
VLOG(3) << jit_compilation_activity.DebugString();
|
||||
Logger::Singleton()->LogProto(jit_compilation_activity);
|
||||
|
||||
if (Logger* logger = Logger::GetSingletonAsync()) {
|
||||
VLOG(2) << "Logging XlaJitCompilationActivity";
|
||||
VLOG(3) << jit_compilation_activity.DebugString();
|
||||
logger->LogProto(jit_compilation_activity);
|
||||
} else {
|
||||
VLOG(2) << "Not logging: logger not ready yet.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -445,7 +445,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
// If we crash on checking failure, we are in a testing/benchmark mode, thus
|
||||
// omitting logging through the logger.
|
||||
if (!crash_on_checking_failure) {
|
||||
tensorflow::Logger::Singleton()->LogProto(log);
|
||||
tensorflow::Logger::GetSingleton()->LogProto(log);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -147,7 +147,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
*log.add_results() = profile;
|
||||
}
|
||||
if (!crash_on_checking_failure) {
|
||||
tensorflow::Logger::Singleton()->LogProto(log);
|
||||
tensorflow::Logger::GetSingleton()->LogProto(log);
|
||||
}
|
||||
|
||||
// Choose fastest correct GEMM, but allow for incorrect results (since the
|
||||
|
@ -564,7 +564,7 @@ void LogFusedConvForwardAutotuneResults(
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
Logger::GetSingleton()->LogProto(log);
|
||||
}
|
||||
|
||||
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
||||
|
@ -498,7 +498,12 @@ cc_library(
|
||||
hdrs = ["platform/logger.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":lib_proto_parsing"],
|
||||
deps = [
|
||||
":lib",
|
||||
":lib_proto_parsing",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -88,7 +88,7 @@ void LogConvAutotuneResults(se::dnn::ConvolutionKind kind,
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
Logger::GetSingleton()->LogProto(log);
|
||||
}
|
||||
|
||||
void LogFusedConvForwardAutotuneResults(
|
||||
@ -126,7 +126,7 @@ void LogFusedConvForwardAutotuneResults(
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
Logger::GetSingleton()->LogProto(log);
|
||||
}
|
||||
|
||||
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -34,4 +37,70 @@ Logger::FactoryFunc Logger::singleton_factory_ = []() -> Logger* {
|
||||
return new DefaultLogger();
|
||||
};
|
||||
|
||||
struct LoggerSingletonContainer {
|
||||
// Used to kick off the construction of a new thread that will asynchronously
|
||||
// construct a Logger.
|
||||
absl::once_flag start_initialization_thread_flag;
|
||||
|
||||
// The constructed logger, if there is one.
|
||||
Logger* logger;
|
||||
|
||||
// The initializing thread notifies `logger_initialized` after storing the
|
||||
// constructed logger to `logger`.
|
||||
absl::Notification logger_initialized;
|
||||
|
||||
// The thread used to construct the Logger instance asynchronously.
|
||||
std::unique_ptr<Thread> initialization_thread;
|
||||
|
||||
// Used to kick off the joining and destruction of `initialization_thread`.
|
||||
absl::once_flag delete_initialization_thread_flag;
|
||||
};
|
||||
|
||||
// Returns the process-wide LoggerSingletonContainer. The instance is
// heap-allocated on the first call and is never deleted by this function.
LoggerSingletonContainer* GetLoggerSingletonContainer() {
  static LoggerSingletonContainer* const instance =
      new LoggerSingletonContainer;
  return instance;
}
|
||||
|
||||
struct AsyncSingletonImpl {
|
||||
static void InitializationThreadFn() {
|
||||
LoggerSingletonContainer* container = GetLoggerSingletonContainer();
|
||||
container->logger = Logger::singleton_factory_();
|
||||
container->logger_initialized.Notify();
|
||||
}
|
||||
|
||||
static void StartInitializationThread(LoggerSingletonContainer* container) {
|
||||
Thread* thread =
|
||||
Env::Default()->StartThread(ThreadOptions{}, "logger-init-thread",
|
||||
AsyncSingletonImpl::InitializationThreadFn);
|
||||
container->initialization_thread.reset(thread);
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the per-process Logger, blocking until its construction is done.
/*static*/ Logger* Logger::GetSingleton() {
  // The async getter starts the initialization thread if nobody has yet; its
  // return value is deliberately ignored here.
  (void)Logger::GetSingletonAsync();

  LoggerSingletonContainer* state = GetLoggerSingletonContainer();

  // Destroying the Thread object is what waits for the initialization thread
  // to finish; the once-flag makes sure only one caller ever does it.
  auto release_thread = [state]() { state->initialization_thread.reset(); };
  absl::call_once(state->delete_initialization_thread_flag, release_thread);

  return state->logger;
}
|
||||
|
||||
// Returns the per-process Logger if it has already been constructed, and
// nullptr otherwise. Never waits for construction; the first call starts the
// background thread that performs it.
/*static*/ Logger* Logger::GetSingletonAsync() {
  LoggerSingletonContainer* state = GetLoggerSingletonContainer();
  absl::call_once(state->start_initialization_thread_flag,
                  AsyncSingletonImpl::StartInitializationThread, state);

  // Logger not published yet: report "not ready" instead of blocking.
  if (!state->logger_initialized.HasBeenNotified()) {
    return nullptr;
  }

  // The logger is ready, so release the initializing thread (once) to reclaim
  // its resources.
  absl::call_once(state->delete_initialization_thread_flag,
                  [state]() { state->initialization_thread.reset(); });
  return state->logger;
}
|
||||
} // namespace tensorflow
|
||||
|
@ -38,10 +38,16 @@ class Logger {
|
||||
singleton_factory_ = factory;
|
||||
}
|
||||
|
||||
static Logger* Singleton() {
|
||||
static Logger* instance = singleton_factory_();
|
||||
return instance;
|
||||
}
|
||||
// Returns the per-process Logger instance, constructing it synchronously if
|
||||
// necessary.
|
||||
static Logger* GetSingleton();
|
||||
|
||||
// Like GetSingleton, except that this does not wait for the construction of
|
||||
// Logger to finish before returning.
|
||||
//
|
||||
// Returns the constructed instance of Logger if it has been constructed,
|
||||
// otherwise returns nullptr (if the logger is not ready yet).
|
||||
static Logger* GetSingletonAsync();
|
||||
|
||||
virtual ~Logger() = default;
|
||||
|
||||
@ -61,6 +67,8 @@ class Logger {
|
||||
virtual void DoFlush() = 0;
|
||||
|
||||
static FactoryFunc singleton_factory_;
|
||||
|
||||
friend struct AsyncSingletonImpl;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user