Allow setting thread name when starting a new thread in posix and windowes. Thread name is important for profiling. It tells user what kind of task is scheduled on this thread.

PiperOrigin-RevId: 239243923
This commit is contained in:
Xiao Yu 2019-03-19 12:20:30 -07:00 committed by TensorFlower Gardener
parent eb1d8f7954
commit e6e857f257
3 changed files with 80 additions and 7 deletions

View File

@ -51,6 +51,11 @@ GraphDef CreateTestProto() {
return g;
}
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(str_util::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
} // namespace
string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); }
@ -408,4 +413,19 @@ TEST_F(DefaultEnvTest, GetThreadInformation) {
#endif
}
TEST_F(DefaultEnvTest, GetChildThreadInformation) {
Env* env = Env::Default();
Thread* child_thread = env->StartThread({}, "tf_child_thread", [env]() {
// TODO(fishx): Turn on this test for Apple.
#if !defined(__APPLE__)
EXPECT_NE(env->GetCurrentThreadId(), 0);
#endif
string thread_name;
bool res = env->GetCurrentThreadName(&thread_name);
EXPECT_TRUE(res);
ExpectHasSubstr(thread_name, "tf_child_thread");
});
delete child_thread;
}
} // namespace tensorflow

View File

@ -32,19 +32,37 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/load_library.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/posix/posix_file_system.h"
namespace tensorflow {
namespace {
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
std::map<std::thread::id, string>& GetThreadNameRegistry()
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
static auto* thread_name_registry = new std::map<std::thread::id, string>();
return *thread_name_registry;
}
class StdThread : public Thread {
public:
// name and thread_options are both ignored.
// thread_options is ignored.
StdThread(const ThreadOptions& thread_options, const string& name,
std::function<void()> fn)
: thread_(fn) {}
~StdThread() override { thread_.join(); }
: thread_(fn) {
mutex_lock l(name_mutex);
GetThreadNameRegistry().emplace(thread_.get_id(), name);
}
~StdThread() override {
std::thread::id thread_id = thread_.get_id();
thread_.join();
mutex_lock l(name_mutex);
GetThreadNameRegistry().erase(thread_id);
}
private:
std::thread thread_;
@ -102,6 +120,15 @@ class PosixEnv : public Env {
}
bool GetCurrentThreadName(string* name) override {
{
mutex_lock l(name_mutex);
auto thread_name =
GetThreadNameRegistry().find(std::this_thread::get_id());
if (thread_name != GetThreadNameRegistry().end()) {
*name = thread_name->second;
return true;
}
}
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
return false;
#else

View File

@ -40,13 +40,30 @@ namespace tensorflow {
namespace {
mutex name_mutex(tensorflow::LINKER_INITIALIZED);
std::map<std::thread::id, string>& GetThreadNameRegistry()
EXCLUSIVE_LOCKS_REQUIRED(name_mutex) {
static auto* thread_name_registry = new std::map<std::thread::id, string>();
return *thread_name_registry;
}
class StdThread : public Thread {
public:
// name and thread_options are both ignored.
// thread_options is ignored.
StdThread(const ThreadOptions& thread_options, const string& name,
std::function<void()> fn)
: thread_(fn) {}
~StdThread() { thread_.join(); }
: thread_(fn) {
mutex_lock l(name_mutex);
GetThreadNameRegistry().emplace(thread_.get_id(), name);
}
~StdThread() override {
std::thread::id thread_id = thread_.get_id();
thread_.join();
mutex_lock l(name_mutex);
GetThreadNameRegistry().erase(thread_id);
}
private:
std::thread thread_;
@ -88,7 +105,16 @@ class WindowsEnv : public Env {
return static_cast<int32>(::GetCurrentThreadId());
}
bool GetCurrentThreadName(string* name) override { return false; }
bool GetCurrentThreadName(string* name) override {
mutex_lock l(name_mutex);
auto thread_name = GetThreadNameRegistry().find(std::this_thread::get_id());
if (thread_name != GetThreadNameRegistry().end()) {
*name = thread_name->second;
return true;
} else {
return false;
}
}
static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance,
PVOID Context, PTP_WORK Work) {