Allow setting thread name when starting a new thread in posix and windowes. Thread name is important for profiling. It tells user what kind of task is scheduled on this thread.
PiperOrigin-RevId: 239243923
This commit is contained in:
parent
eb1d8f7954
commit
e6e857f257
@ -51,6 +51,11 @@ GraphDef CreateTestProto() {
|
||||
return g;
|
||||
}
|
||||
|
||||
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); }
|
||||
@ -408,4 +413,19 @@ TEST_F(DefaultEnvTest, GetThreadInformation) {
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(DefaultEnvTest, GetChildThreadInformation) {
|
||||
Env* env = Env::Default();
|
||||
Thread* child_thread = env->StartThread({}, "tf_child_thread", [env]() {
|
||||
// TODO(fishx): Turn on this test for Apple.
|
||||
#if !defined(__APPLE__)
|
||||
EXPECT_NE(env->GetCurrentThreadId(), 0);
|
||||
#endif
|
||||
string thread_name;
|
||||
bool res = env->GetCurrentThreadName(&thread_name);
|
||||
EXPECT_TRUE(res);
|
||||
ExpectHasSubstr(thread_name, "tf_child_thread");
|
||||
});
|
||||
delete child_thread;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -32,19 +32,37 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/load_library.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/posix/posix_file_system.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
std::map<std::thread::id, string>& GetThreadNameRegistry()
|
||||
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
|
||||
static auto* thread_name_registry = new std::map<std::thread::id, string>();
|
||||
return *thread_name_registry;
|
||||
}
|
||||
|
||||
class StdThread : public Thread {
|
||||
public:
|
||||
// name and thread_options are both ignored.
|
||||
// thread_options is ignored.
|
||||
StdThread(const ThreadOptions& thread_options, const string& name,
|
||||
std::function<void()> fn)
|
||||
: thread_(fn) {}
|
||||
~StdThread() override { thread_.join(); }
|
||||
: thread_(fn) {
|
||||
mutex_lock l(name_mutex);
|
||||
GetThreadNameRegistry().emplace(thread_.get_id(), name);
|
||||
}
|
||||
|
||||
~StdThread() override {
|
||||
std::thread::id thread_id = thread_.get_id();
|
||||
thread_.join();
|
||||
mutex_lock l(name_mutex);
|
||||
GetThreadNameRegistry().erase(thread_id);
|
||||
}
|
||||
|
||||
private:
|
||||
std::thread thread_;
|
||||
@ -102,6 +120,15 @@ class PosixEnv : public Env {
|
||||
}
|
||||
|
||||
bool GetCurrentThreadName(string* name) override {
|
||||
{
|
||||
mutex_lock l(name_mutex);
|
||||
auto thread_name =
|
||||
GetThreadNameRegistry().find(std::this_thread::get_id());
|
||||
if (thread_name != GetThreadNameRegistry().end()) {
|
||||
*name = thread_name->second;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
|
||||
return false;
|
||||
#else
|
||||
|
@ -40,13 +40,30 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
std::map<std::thread::id, string>& GetThreadNameRegistry()
|
||||
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
|
||||
static auto* thread_name_registry = new std::map<std::thread::id, string>();
|
||||
return *thread_name_registry;
|
||||
}
|
||||
|
||||
class StdThread : public Thread {
|
||||
public:
|
||||
// name and thread_options are both ignored.
|
||||
// thread_options is ignored.
|
||||
StdThread(const ThreadOptions& thread_options, const string& name,
|
||||
std::function<void()> fn)
|
||||
: thread_(fn) {}
|
||||
~StdThread() { thread_.join(); }
|
||||
: thread_(fn) {
|
||||
mutex_lock l(name_mutex);
|
||||
GetThreadNameRegistry().emplace(thread_.get_id(), name);
|
||||
}
|
||||
|
||||
~StdThread() override {
|
||||
std::thread::id thread_id = thread_.get_id();
|
||||
thread_.join();
|
||||
mutex_lock l(name_mutex);
|
||||
GetThreadNameRegistry().erase(thread_id);
|
||||
}
|
||||
|
||||
private:
|
||||
std::thread thread_;
|
||||
@ -88,7 +105,16 @@ class WindowsEnv : public Env {
|
||||
return static_cast<int32>(::GetCurrentThreadId());
|
||||
}
|
||||
|
||||
bool GetCurrentThreadName(string* name) override { return false; }
|
||||
bool GetCurrentThreadName(string* name) override {
|
||||
mutex_lock l(name_mutex);
|
||||
auto thread_name = GetThreadNameRegistry().find(std::this_thread::get_id());
|
||||
if (thread_name != GetThreadNameRegistry().end()) {
|
||||
*name = thread_name->second;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance,
|
||||
PVOID Context, PTP_WORK Work) {
|
||||
|
Loading…
Reference in New Issue
Block a user