From e6e857f2578670823b0adb01bac711dcf1b5369f Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Tue, 19 Mar 2019 12:20:30 -0700 Subject: [PATCH] Allow setting thread name when starting a new thread in posix and windowes. Thread name is important for profiling. It tells user what kind of task is scheduled on this thread. PiperOrigin-RevId: 239243923 --- tensorflow/core/platform/env_test.cc | 20 +++++++++++++++ tensorflow/core/platform/posix/env.cc | 33 +++++++++++++++++++++--- tensorflow/core/platform/windows/env.cc | 34 ++++++++++++++++++++++--- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index ea1f1234247..593727e850d 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -51,6 +51,11 @@ GraphDef CreateTestProto() { return g; } +static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(str_util::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + } // namespace string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); } @@ -408,4 +413,19 @@ TEST_F(DefaultEnvTest, GetThreadInformation) { #endif } +TEST_F(DefaultEnvTest, GetChildThreadInformation) { + Env* env = Env::Default(); + Thread* child_thread = env->StartThread({}, "tf_child_thread", [env]() { + // TODO(fishx): Turn on this test for Apple. +#if !defined(__APPLE__) + EXPECT_NE(env->GetCurrentThreadId(), 0); +#endif + string thread_name; + bool res = env->GetCurrentThreadName(&thread_name); + EXPECT_TRUE(res); + ExpectHasSubstr(thread_name, "tf_child_thread"); + }); + delete child_thread; +} + } // namespace tensorflow diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc index f2dff5a9b64..2700a269c4d 100644 --- a/tensorflow/core/platform/posix/env.cc +++ b/tensorflow/core/platform/posix/env.cc @@ -32,19 +32,37 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/load_library.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/posix/posix_file_system.h" namespace tensorflow { namespace { +mutex name_mutex(tensorflow::LINKER_INITIALIZED); + +std::map& GetThreadNameRegistry() + EXCLUSIVE_LOCKS_REQUIRED(name_mutex) { + static auto* thread_name_registry = new std::map(); + return *thread_name_registry; +} + class StdThread : public Thread { public: - // name and thread_options are both ignored. + // thread_options is ignored. StdThread(const ThreadOptions& thread_options, const string& name, std::function fn) - : thread_(fn) {} - ~StdThread() override { thread_.join(); } + : thread_(fn) { + mutex_lock l(name_mutex); + GetThreadNameRegistry().emplace(thread_.get_id(), name); + } + + ~StdThread() override { + std::thread::id thread_id = thread_.get_id(); + thread_.join(); + mutex_lock l(name_mutex); + GetThreadNameRegistry().erase(thread_id); + } private: std::thread thread_; @@ -102,6 +120,15 @@ class PosixEnv : public Env { } bool GetCurrentThreadName(string* name) override { + { + mutex_lock l(name_mutex); + auto thread_name = + GetThreadNameRegistry().find(std::this_thread::get_id()); + if (thread_name != GetThreadNameRegistry().end()) { + *name = thread_name->second; + return true; + } + } #if defined(__ANDROID__) || defined(__EMSCRIPTEN__) return false; #else diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index e0e3dda7055..fedbd674d5f 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -40,13 +40,30 @@ namespace tensorflow { namespace { +mutex name_mutex(tensorflow::LINKER_INITIALIZED); + +std::map& GetThreadNameRegistry() + EXCLUSIVE_LOCKS_REQUIRED(name_mutex) { + static auto* thread_name_registry = new std::map(); + return *thread_name_registry; +} + class StdThread : public Thread { public: - // name and thread_options are both ignored. + // thread_options is ignored. StdThread(const ThreadOptions& thread_options, const string& name, std::function fn) - : thread_(fn) {} - ~StdThread() { thread_.join(); } + : thread_(fn) { + mutex_lock l(name_mutex); + GetThreadNameRegistry().emplace(thread_.get_id(), name); + } + + ~StdThread() override { + std::thread::id thread_id = thread_.get_id(); + thread_.join(); + mutex_lock l(name_mutex); + GetThreadNameRegistry().erase(thread_id); + } private: std::thread thread_; @@ -88,7 +105,16 @@ class WindowsEnv : public Env { return static_cast(::GetCurrentThreadId()); } - bool GetCurrentThreadName(string* name) override { return false; } + bool GetCurrentThreadName(string* name) override { + mutex_lock l(name_mutex); + auto thread_name = GetThreadNameRegistry().find(std::this_thread::get_id()); + if (thread_name != GetThreadNameRegistry().end()) { + *name = thread_name->second; + return true; + } else { + return false; + } + } static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work) {