Allow setting thread name when starting a new thread in posix and windowes. Thread name is important for profiling. It tells user what kind of task is scheduled on this thread.
PiperOrigin-RevId: 239243923
This commit is contained in:
parent
eb1d8f7954
commit
e6e857f257
@ -51,6 +51,11 @@ GraphDef CreateTestProto() {
|
|||||||
return g;
|
return g;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||||
|
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||||
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); }
|
string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); }
|
||||||
@ -408,4 +413,19 @@ TEST_F(DefaultEnvTest, GetThreadInformation) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DefaultEnvTest, GetChildThreadInformation) {
|
||||||
|
Env* env = Env::Default();
|
||||||
|
Thread* child_thread = env->StartThread({}, "tf_child_thread", [env]() {
|
||||||
|
// TODO(fishx): Turn on this test for Apple.
|
||||||
|
#if !defined(__APPLE__)
|
||||||
|
EXPECT_NE(env->GetCurrentThreadId(), 0);
|
||||||
|
#endif
|
||||||
|
string thread_name;
|
||||||
|
bool res = env->GetCurrentThreadName(&thread_name);
|
||||||
|
EXPECT_TRUE(res);
|
||||||
|
ExpectHasSubstr(thread_name, "tf_child_thread");
|
||||||
|
});
|
||||||
|
delete child_thread;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -32,19 +32,37 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/load_library.h"
|
#include "tensorflow/core/platform/load_library.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/posix/posix_file_system.h"
|
#include "tensorflow/core/platform/posix/posix_file_system.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
|
||||||
|
|
||||||
|
std::map<std::thread::id, string>& GetThreadNameRegistry()
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
|
||||||
|
static auto* thread_name_registry = new std::map<std::thread::id, string>();
|
||||||
|
return *thread_name_registry;
|
||||||
|
}
|
||||||
|
|
||||||
class StdThread : public Thread {
|
class StdThread : public Thread {
|
||||||
public:
|
public:
|
||||||
// name and thread_options are both ignored.
|
// thread_options is ignored.
|
||||||
StdThread(const ThreadOptions& thread_options, const string& name,
|
StdThread(const ThreadOptions& thread_options, const string& name,
|
||||||
std::function<void()> fn)
|
std::function<void()> fn)
|
||||||
: thread_(fn) {}
|
: thread_(fn) {
|
||||||
~StdThread() override { thread_.join(); }
|
mutex_lock l(name_mutex);
|
||||||
|
GetThreadNameRegistry().emplace(thread_.get_id(), name);
|
||||||
|
}
|
||||||
|
|
||||||
|
~StdThread() override {
|
||||||
|
std::thread::id thread_id = thread_.get_id();
|
||||||
|
thread_.join();
|
||||||
|
mutex_lock l(name_mutex);
|
||||||
|
GetThreadNameRegistry().erase(thread_id);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::thread thread_;
|
std::thread thread_;
|
||||||
@ -102,6 +120,15 @@ class PosixEnv : public Env {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool GetCurrentThreadName(string* name) override {
|
bool GetCurrentThreadName(string* name) override {
|
||||||
|
{
|
||||||
|
mutex_lock l(name_mutex);
|
||||||
|
auto thread_name =
|
||||||
|
GetThreadNameRegistry().find(std::this_thread::get_id());
|
||||||
|
if (thread_name != GetThreadNameRegistry().end()) {
|
||||||
|
*name = thread_name->second;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
|
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
|
||||||
return false;
|
return false;
|
||||||
#else
|
#else
|
||||||
|
@ -40,13 +40,30 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
|
||||||
|
|
||||||
|
std::map<std::thread::id, string>& GetThreadNameRegistry()
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
|
||||||
|
static auto* thread_name_registry = new std::map<std::thread::id, string>();
|
||||||
|
return *thread_name_registry;
|
||||||
|
}
|
||||||
|
|
||||||
class StdThread : public Thread {
|
class StdThread : public Thread {
|
||||||
public:
|
public:
|
||||||
// name and thread_options are both ignored.
|
// thread_options is ignored.
|
||||||
StdThread(const ThreadOptions& thread_options, const string& name,
|
StdThread(const ThreadOptions& thread_options, const string& name,
|
||||||
std::function<void()> fn)
|
std::function<void()> fn)
|
||||||
: thread_(fn) {}
|
: thread_(fn) {
|
||||||
~StdThread() { thread_.join(); }
|
mutex_lock l(name_mutex);
|
||||||
|
GetThreadNameRegistry().emplace(thread_.get_id(), name);
|
||||||
|
}
|
||||||
|
|
||||||
|
~StdThread() override {
|
||||||
|
std::thread::id thread_id = thread_.get_id();
|
||||||
|
thread_.join();
|
||||||
|
mutex_lock l(name_mutex);
|
||||||
|
GetThreadNameRegistry().erase(thread_id);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::thread thread_;
|
std::thread thread_;
|
||||||
@ -88,7 +105,16 @@ class WindowsEnv : public Env {
|
|||||||
return static_cast<int32>(::GetCurrentThreadId());
|
return static_cast<int32>(::GetCurrentThreadId());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetCurrentThreadName(string* name) override { return false; }
|
bool GetCurrentThreadName(string* name) override {
|
||||||
|
mutex_lock l(name_mutex);
|
||||||
|
auto thread_name = GetThreadNameRegistry().find(std::this_thread::get_id());
|
||||||
|
if (thread_name != GetThreadNameRegistry().end()) {
|
||||||
|
*name = thread_name->second;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance,
|
static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance,
|
||||||
PVOID Context, PTP_WORK Work) {
|
PVOID Context, PTP_WORK Work) {
|
||||||
|
Loading…
Reference in New Issue
Block a user