Introduce sub thread pools to run handler thread pool.
PiperOrigin-RevId: 279237293 Change-Id: I8f1945105d27d1ed06eee8bb914bfaa06fec6c1f
This commit is contained in:
parent
af019188ad
commit
bfce5bae88
@ -4468,11 +4468,18 @@ tf_cc_test(
|
||||
srcs = ["framework/run_handler_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":direct_session_internal",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":protos_all_cc",
|
||||
":tensor_testutil",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
|
@ -98,6 +98,55 @@ class RunHandlerEnvironment {
|
||||
typedef typename RunHandlerEnvironment::Task Task;
|
||||
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||
|
||||
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
||||
// queue them in LIFO order rather than the FIFO order used by a single
|
||||
// condition variable.
|
||||
struct Waiter {
|
||||
Waiter() {
|
||||
next = this;
|
||||
prev = this;
|
||||
}
|
||||
condition_variable cv;
|
||||
mutex mu;
|
||||
Waiter* next;
|
||||
Waiter* prev;
|
||||
};
|
||||
|
||||
// Sleeps on `waiter` for at most `max_sleep_micros` microseconds.
//
// The waiter is first linked directly behind `queue_head` while holding
// `*mutex`, then the thread waits on the waiter's own condition variable, and
// finally the waiter is unlinked again (if it is still linked) under `*mutex`.
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
                  int max_sleep_micros) {
  {
    mutex_lock queue_lock(*mutex);
    // The waiter must be self-linked, i.e. not part of a queue yet.
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Insert right after the head so the most recent waiter is found first.
    Waiter* const previous_first = queue_head->next;
    waiter->prev = queue_head;
    waiter->next = previous_first;
    previous_first->prev = waiter;
    queue_head->next = waiter;
  }
  {
    // Block until notified or until the timeout elapses.
    mutex_lock sleep_lock(waiter->mu);
    waiter->cv.wait_for(sleep_lock,
                        std::chrono::microseconds(max_sleep_micros));
  }

  mutex_lock queue_lock(*mutex);
  // Waking up because of a notification does not prove the waiter has left
  // the queue: a notifier preempted right before notifying may resume only
  // after the waiter got re-added. So always inspect the links.
  const bool still_queued = waiter->next != waiter;
  if (!still_queued) {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
    return;
  }

  CHECK(waiter->prev != waiter);  // Crash OK.
  // Splice the waiter out and restore its self-linked state.
  waiter->next->prev = waiter->prev;
  waiter->prev->next = waiter->next;
  waiter->next = waiter;
  waiter->prev = waiter;
}
|
||||
|
||||
class ThreadWorkSource {
|
||||
public:
|
||||
ThreadWorkSource()
|
||||
@ -155,11 +204,32 @@ class ThreadWorkSource {
|
||||
if (max_rank_to_wakeup > 0 &&
|
||||
rank_.load(std::memory_order_relaxed) <= max_rank_to_wakeup) {
|
||||
Waiter* w = nullptr;
|
||||
bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
|
||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
|
||||
|
||||
Waiter* waiter_queue;
|
||||
mutex* waiter_queue_mu;
|
||||
if (use_sub_thread_pool) {
|
||||
// When we use multiple sub thread pools, free threads wait on sub
|
||||
// thread pool waiting queues. Wake up threads from sub thread waiting
|
||||
// queues.
|
||||
// The waiting queues are defined at RunHandlerPool.
|
||||
// Get the waiter_queue and corresponding mutex. Note, the thread work
|
||||
// source may change afterwards if a new request comes or an old request
|
||||
// finishes.
|
||||
tf_shared_lock lock(run_handler_waiter_mu_);
|
||||
waiter_queue = sub_thread_pool_waiter_;
|
||||
waiter_queue_mu = sub_thread_pool_waiter_mu_;
|
||||
} else {
|
||||
waiter_queue = &queue_waiters_;
|
||||
waiter_queue_mu = &waiters_mu_;
|
||||
}
|
||||
|
||||
{
|
||||
mutex_lock l(waiters_mu_);
|
||||
if (queue_waiters_.next != &queue_waiters_) {
|
||||
mutex_lock l(*waiter_queue_mu);
|
||||
if (waiter_queue->next != waiter_queue) {
|
||||
// Remove waiter from the LIFO queue
|
||||
w = queue_waiters_.next;
|
||||
w = waiter_queue->next;
|
||||
|
||||
CHECK(w->prev != w);
|
||||
CHECK(w->next != w);
|
||||
@ -187,43 +257,25 @@ class ThreadWorkSource {
|
||||
|
||||
Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
|
||||
|
||||
Task PopNonBlockingTask(int index) {
|
||||
return non_blocking_work_queues_[index]->queue.PopBack();
|
||||
// Pops a task from the non-blocking work queues.
//
// Probing starts at queue `start_index % sharding factor`. When
// `search_from_all_queue` is true, every queue is tried in round-robin order
// until a task is found; otherwise only the starting queue is tried. Returns
// the last popped task, whose `f` is empty when nothing was found.
Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
  Task task;
  const unsigned shard_count = NonBlockingWorkShardingFactor();
  // Number of queues to probe: all of them, or just the first one.
  unsigned probe_limit = shard_count;
  if (!search_from_all_queue && shard_count > 1) {
    probe_limit = 1;
  }
  for (unsigned offset = 0; offset < probe_limit; ++offset) {
    const unsigned shard = (start_index + offset) % shard_count;
    task = non_blocking_work_queues_[shard]->queue.PopBack();
    if (task.f) {
      break;
    }
  }
  return task;
}
|
||||
|
||||
void WaitForWork(int max_sleep_micros) {
|
||||
thread_local Waiter waiter;
|
||||
{
|
||||
mutex_lock l(waiters_mu_);
|
||||
CHECK_EQ(waiter.next, &waiter);
|
||||
CHECK_EQ(waiter.prev, &waiter);
|
||||
|
||||
// Add waiter to the LIFO queue
|
||||
waiter.prev = &queue_waiters_;
|
||||
waiter.next = queue_waiters_.next;
|
||||
waiter.next->prev = &waiter;
|
||||
waiter.prev->next = &waiter;
|
||||
}
|
||||
{
|
||||
mutex_lock l(waiter.mu);
|
||||
// Wait on the condition variable
|
||||
waiter.cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
|
||||
}
|
||||
|
||||
mutex_lock l(waiters_mu_);
|
||||
// Remove waiter from the LIFO queue. Note even when a waiter wakes up due
|
||||
// to a notification we cannot conclude the waiter is not in the queue.
|
||||
// This is due to the fact that a thread preempted right before notifying
|
||||
// may resume after a waiter got re-added.
|
||||
if (waiter.next != &waiter) {
|
||||
CHECK(waiter.prev != &waiter);
|
||||
waiter.next->prev = waiter.prev;
|
||||
waiter.prev->next = waiter.next;
|
||||
waiter.next = &waiter;
|
||||
waiter.prev = &waiter;
|
||||
} else {
|
||||
CHECK_EQ(waiter.prev, &waiter);
|
||||
}
|
||||
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
|
||||
}
|
||||
|
||||
int TaskQueueSize(bool is_blocking) {
|
||||
@ -243,6 +295,12 @@ class ThreadWorkSource {
|
||||
void SetTracemeId(int64 value) { traceme_id_ = value; }
|
||||
void SetRank(int64 value) { rank_ = value; }
|
||||
|
||||
// Stores the waiter queue head and the mutex handed in with it. Both
// pointers are written while holding run_handler_waiter_mu_.
void SetWaiter(Waiter* waiter, mutex* mutex) {
  mutex_lock lock(run_handler_waiter_mu_);
  sub_thread_pool_waiter_mu_ = mutex;
  sub_thread_pool_waiter_ = waiter;
}
|
||||
|
||||
int64 GetInflightTaskCount(bool is_blocking) {
|
||||
std::atomic<int64>* counter =
|
||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||
@ -274,20 +332,6 @@ class ThreadWorkSource {
|
||||
}
|
||||
|
||||
private:
|
||||
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
||||
// queue them in LIFO order rather than the FIFO order used by a single
|
||||
// condition variable.
|
||||
struct Waiter {
|
||||
Waiter() {
|
||||
next = this;
|
||||
prev = this;
|
||||
}
|
||||
condition_variable cv;
|
||||
mutex mu;
|
||||
Waiter* next;
|
||||
Waiter* prev;
|
||||
};
|
||||
|
||||
struct NonBlockingQueue {
|
||||
mutex queue_op_mu;
|
||||
char pad[128];
|
||||
@ -307,6 +351,10 @@ class ThreadWorkSource {
|
||||
Waiter queue_waiters_ GUARDED_BY(waiters_mu_);
|
||||
std::atomic<int64> traceme_id_;
|
||||
std::atomic<int64> rank_;
|
||||
|
||||
mutex run_handler_waiter_mu_;
|
||||
mutex* sub_thread_pool_waiter_mu_ GUARDED_BY(run_handler_waiter_mu_);
|
||||
Waiter* sub_thread_pool_waiter_ GUARDED_BY(run_handler_waiter_mu_);
|
||||
};
|
||||
|
||||
class RunHandlerThreadPool {
|
||||
@ -319,25 +367,33 @@ class RunHandlerThreadPool {
|
||||
|
||||
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
|
||||
Env* env, const ThreadOptions& thread_options,
|
||||
const string& name)
|
||||
const string& name,
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters)
|
||||
: num_threads_(num_blocking_threads + num_non_blocking_threads),
|
||||
num_blocking_threads_(num_blocking_threads),
|
||||
num_non_blocking_threads_(num_non_blocking_threads),
|
||||
thread_data_(num_threads_),
|
||||
env_(env, thread_options, name),
|
||||
name_(name) {
|
||||
name_(name),
|
||||
waiters_mu_(waiters_mu),
|
||||
queue_waiters_(queue_waiters),
|
||||
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
|
||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
|
||||
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
|
||||
std::vector<int>(
|
||||
{num_blocking_threads / 2,
|
||||
num_blocking_threads - num_blocking_threads / 2}))),
|
||||
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
|
||||
std::vector<double>({0, 0.4}))),
|
||||
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
||||
std::vector<double>({0.4, 1}))) {
|
||||
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
|
||||
<< num_blocking_threads_ << " blocking threads and "
|
||||
<< num_non_blocking_threads_ << " non-blocking threads.";
|
||||
cancelled_ = false;
|
||||
|
||||
thread_data_.resize(num_threads_);
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
thread_data_[i].thread.reset(
|
||||
env_.CreateThread([this, i, num_blocking_threads]() {
|
||||
WorkerLoop(i, i < num_blocking_threads);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
~RunHandlerThreadPool() {
|
||||
@ -353,6 +409,26 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
void Start() {
|
||||
cancelled_ = false;
|
||||
thread_data_.resize(num_threads_);
|
||||
int num_blocking_threads = num_blocking_threads_;
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
|
||||
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
|
||||
if (i < num_threads_in_sub_thread_pool_[j]) {
|
||||
sub_thread_pool_id = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
|
||||
thread_data_[i].thread.reset(
|
||||
env_.CreateThread([this, i, num_blocking_threads]() {
|
||||
WorkerLoop(i, i < num_blocking_threads);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
|
||||
std::function<void()> fn) {
|
||||
Task t = env_.CreateTask(std::move(fn));
|
||||
@ -384,30 +460,37 @@ class RunHandlerThreadPool {
|
||||
return;
|
||||
}
|
||||
thread_data_[tid].thread_work_sources.resize(0);
|
||||
thread_data_[tid].thread_work_sources.emplace_back(
|
||||
thread_work_sources[start_request_idx]);
|
||||
// The number of shards for the queue. Threads in each shard will prioritize
|
||||
// different thread_work_sources. Increase the number of shards could
|
||||
// decrease the contention in the queue.
|
||||
// For example, when num_shards == 1:
|
||||
// thread_work_sources are ordered as start_request_idx, 0, 1, 2, 3, 4 ...
|
||||
// for all threads.
|
||||
// When num_shards == 2:
|
||||
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
|
||||
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
|
||||
// 4... for the other half of the threads.
|
||||
int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
||||
int token = tid % num_shards;
|
||||
for (int i = 0; i < num_shards; ++i) {
|
||||
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
|
||||
if (j != start_request_idx) {
|
||||
thread_data_[tid].thread_work_sources.emplace_back(
|
||||
thread_work_sources[j]);
|
||||
}
|
||||
|
||||
if (use_sub_thread_pool_) {
|
||||
for (int i = 0; i < thread_work_sources.size(); ++i) {
|
||||
thread_data_[tid].thread_work_sources.emplace_back(
|
||||
thread_work_sources[i]);
|
||||
}
|
||||
token = (token + 1) % num_shards;
|
||||
} else {
|
||||
thread_data_[tid].thread_work_sources.emplace_back(
|
||||
thread_work_sources[start_request_idx]);
|
||||
// The number of shards for the queue. Threads in each shard will
|
||||
// prioritize different thread_work_sources. Increase the number of shards
|
||||
// could decrease the contention in the queue. For example, when
|
||||
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
|
||||
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
|
||||
// thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
|
||||
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
|
||||
// 4... for the other half of the threads.
|
||||
int num_shards =
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
||||
int token = tid % num_shards;
|
||||
for (int i = 0; i < num_shards; ++i) {
|
||||
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
|
||||
if (j != start_request_idx) {
|
||||
thread_data_[tid].thread_work_sources.emplace_back(
|
||||
thread_work_sources[j]);
|
||||
}
|
||||
}
|
||||
token = (token + 1) % num_shards;
|
||||
}
|
||||
thread_data_[tid].sources_not_empty.notify_all();
|
||||
}
|
||||
thread_data_[tid].sources_not_empty.notify_all();
|
||||
}
|
||||
|
||||
PerThread* GetPerThread() {
|
||||
@ -434,13 +517,26 @@ class RunHandlerThreadPool {
|
||||
|
||||
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
|
||||
|
||||
// Search tasks from requests in the range searching_range_start to
|
||||
// searching_range_end. If there are no tasks in the search range and
|
||||
// may_steal_blocking_work is true, then search from all requests.
|
||||
Task FindTask(
|
||||
int searching_range_start, int searching_range_end, int thread_id,
|
||||
int sub_thread_pool_id, int max_blocking_inflight,
|
||||
bool may_steal_blocking_work,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
|
||||
bool* task_from_blocking_queue, ThreadWorkSource** tws);
|
||||
|
||||
void WaitForWork(bool is_blocking, int thread_id,
|
||||
int32 max_blocking_inflight);
|
||||
|
||||
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
|
||||
|
||||
private:
|
||||
struct ThreadData {
|
||||
ThreadData()
|
||||
: version(0),
|
||||
current_index(0),
|
||||
thread_work_sources(static_cast<int32>(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
kMaxConcurrentHandlers))) {}
|
||||
@ -448,7 +544,9 @@ class RunHandlerThreadPool {
|
||||
uint64 version;
|
||||
condition_variable sources_not_empty;
|
||||
std::unique_ptr<Thread> thread;
|
||||
int current_index;
|
||||
Eigen::MaxSizeVector<ThreadWorkSource*> thread_work_sources GUARDED_BY(mu);
|
||||
int sub_thread_pool_id;
|
||||
};
|
||||
|
||||
const int num_threads_;
|
||||
@ -458,8 +556,58 @@ class RunHandlerThreadPool {
|
||||
RunHandlerEnvironment env_;
|
||||
std::atomic<bool> cancelled_;
|
||||
string name_;
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu_;
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
|
||||
|
||||
bool use_sub_thread_pool_;
|
||||
std::vector<int> num_threads_in_sub_thread_pool_;
|
||||
|
||||
// Threads in each sub thread pool will search tasks from the given
|
||||
// start_request_percentage to end_request_percentage in a round robin
|
||||
// fashion.
|
||||
std::vector<double> sub_thread_pool_start_request_percentage_;
|
||||
std::vector<double> sub_thread_pool_end_request_percentage_;
|
||||
};
|
||||
|
||||
// Looks for a runnable task among
// thread_work_sources[searching_range_start, searching_range_end).
//
// The scan is round robin: it resumes from the per-thread current_index saved
// by the previous call and visits each source in the range at most once. For
// every visited source, a blocking task is tried first when
// `may_steal_blocking_work` is true and the source has fewer than
// `max_blocking_inflight` blocking tasks in flight; then a non-blocking task
// is tried across all of the source's non-blocking queues.
//
// On return, `*task_from_blocking_queue` tells whether the returned task came
// from a blocking queue and `*tws` points at the last source visited (it is
// left untouched when the range is empty). The returned task has an empty `f`
// when nothing was found.
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  // TODO(chaox): Change the search algorithm from round robin to random
  // walk.
  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    // The saved index may lie outside this range on either side, because
    // the same per-thread index is shared by searches over different ranges.
    // Restart at the range start in both cases so that only sources inside
    // the requested range are visited.
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For blocking thread, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  // Remember where to resume on the next call for this thread.
  thread_data_[thread_id].current_index = current_index;
  return t;
}
|
||||
|
||||
// Main worker thread loop.
|
||||
void RunHandlerThreadPool::WorkerLoop(int thread_id,
|
||||
bool may_steal_blocking_work) {
|
||||
@ -474,11 +622,47 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
|
||||
bool task_from_blocking_queue = true;
|
||||
Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
|
||||
&thread_data_[thread_id].thread_work_sources;
|
||||
{
|
||||
int sub_thread_pool_id;
|
||||
if (use_sub_thread_pool_) {
|
||||
// The mutex is not hot since it's per thread and can only be held
|
||||
// by some other thread when a session run starts/finishes.
|
||||
mutex_lock l(thread_data_[thread_id].mu);
|
||||
sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
|
||||
int active_requests = thread_work_sources->size();
|
||||
if (may_steal_blocking_work) {
|
||||
// Each thread will first look for tasks from requests that belong to
|
||||
// its sub thread pool.
|
||||
t = FindTask(
|
||||
active_requests *
|
||||
sub_thread_pool_start_request_percentage_[sub_thread_pool_id],
|
||||
active_requests *
|
||||
sub_thread_pool_end_request_percentage_[sub_thread_pool_id],
|
||||
thread_id, sub_thread_pool_id, kMaxBlockingInflight,
|
||||
/*may_steal_blocking_work=*/true, *thread_work_sources,
|
||||
&task_from_blocking_queue, &tws);
|
||||
if (!t.f) {
|
||||
// Search from all requests if the thread cannot find tasks from
|
||||
// requests that belong to its own sub thread pool.
|
||||
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
|
||||
kMaxBlockingInflight,
|
||||
/*may_steal_blocking_work=*/true, *thread_work_sources,
|
||||
&task_from_blocking_queue, &tws);
|
||||
}
|
||||
} else {
|
||||
// For non-blocking threads, it will always search from all pending
|
||||
// requests.
|
||||
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
|
||||
kMaxBlockingInflight,
|
||||
/*may_steal_blocking_work=*/false, *thread_work_sources,
|
||||
&task_from_blocking_queue, &tws);
|
||||
}
|
||||
} else {
|
||||
// The mutex is not hot since it's per thread and can only be held
|
||||
// by some other thread when a session run starts/finishes.
|
||||
mutex_lock l(thread_data_[thread_id].mu);
|
||||
|
||||
// TODO(chaox): Refactor the following code to share the logic with
|
||||
// FindTask.
|
||||
for (int i = 0; i < thread_work_sources->size(); ++i) {
|
||||
tws = (*thread_work_sources)[i];
|
||||
// We want a smallish numbers of inter threads since
|
||||
@ -495,20 +679,16 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
|
||||
// Always look for any work from the "primary" work source.
|
||||
// This way when we wake up a thread for a new closure we are
|
||||
// guaranteed it can be worked on.
|
||||
for (int j = 0; j < tws->NonBlockingWorkShardingFactor(); ++j) {
|
||||
t = tws->PopNonBlockingTask((j + thread_id) %
|
||||
tws->NonBlockingWorkShardingFactor());
|
||||
if (t.f) {
|
||||
task_from_blocking_queue = false;
|
||||
break;
|
||||
}
|
||||
t = tws->PopNonBlockingTask(thread_id, true);
|
||||
if (t.f) {
|
||||
task_from_blocking_queue = false;
|
||||
break;
|
||||
}
|
||||
if (t.f) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
t = tws->PopNonBlockingTask(thread_id %
|
||||
tws->NonBlockingWorkShardingFactor());
|
||||
t = tws->PopNonBlockingTask(thread_id, false);
|
||||
if (t.f) {
|
||||
task_from_blocking_queue = false;
|
||||
break;
|
||||
@ -542,12 +722,30 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
|
||||
<< (*thread_work_sources)[i]->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
|
||||
if (use_sub_thread_pool_) {
|
||||
WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
|
||||
} else {
|
||||
WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
|
||||
int sub_thread_pool_id) {
|
||||
const int kMaxSleepMicros = 250;
|
||||
|
||||
// The non-blocking thread will just sleep.
|
||||
if (!is_blocking) {
|
||||
Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
|
||||
return;
|
||||
}
|
||||
|
||||
thread_local Waiter waiter;
|
||||
WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
|
||||
&(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
|
||||
}
|
||||
|
||||
void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
|
||||
int32 max_blocking_inflight) {
|
||||
const int kMaxSleepMicros = 250;
|
||||
@ -636,16 +834,33 @@ class RunHandlerPool::Impl {
|
||||
explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
|
||||
: max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
|
||||
waiters_mu_(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||
queue_waiters_(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||
run_handler_thread_pool_(new RunHandlerThreadPool(
|
||||
num_inter_op_threads, num_intra_op_threads, Env::Default(),
|
||||
ThreadOptions(), "tf_run_handler_pool")),
|
||||
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
|
||||
&queue_waiters_)),
|
||||
iterations_(0),
|
||||
version_(0) {
|
||||
version_(0),
|
||||
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
||||
std::vector<double>({1}))) {
|
||||
VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
|
||||
for (int i = 0; i < max_handlers_; ++i) {
|
||||
handlers_.emplace_back(new RunHandler::Impl(this));
|
||||
free_handlers_.push_back(handlers_.back().get());
|
||||
}
|
||||
queue_waiters_.resize(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
|
||||
waiters_mu_.resize(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
|
||||
for (auto& queue_waiter : queue_waiters_) {
|
||||
queue_waiter.next = &queue_waiter;
|
||||
queue_waiter.prev = &queue_waiter;
|
||||
}
|
||||
run_handler_thread_pool_->Start();
|
||||
}
|
||||
|
||||
~Impl() {
|
||||
@ -693,6 +908,19 @@ class RunHandlerPool::Impl {
|
||||
for (int i = 0; i < num_active_requests; ++i) {
|
||||
(*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
|
||||
(*thread_work_sources)[i]->SetRank(i);
|
||||
int sub_thread_pool_id =
|
||||
sub_thread_pool_end_request_percentage_.size() - 1;
|
||||
for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
|
||||
++j) {
|
||||
if (i < num_active_requests *
|
||||
sub_thread_pool_end_request_percentage_[j]) {
|
||||
sub_thread_pool_id = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
(*thread_work_sources)[i]->SetWaiter(
|
||||
&queue_waiters_[sub_thread_pool_id],
|
||||
&waiters_mu_[sub_thread_pool_id]);
|
||||
}
|
||||
version = ++version_;
|
||||
}
|
||||
@ -738,6 +966,19 @@ class RunHandlerPool::Impl {
|
||||
for (int i = 0; i < num_active_requests; ++i) {
|
||||
(*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
|
||||
(*thread_work_sources)[i]->SetRank(i);
|
||||
int sub_thread_pool_id =
|
||||
sub_thread_pool_end_request_percentage_.size() - 1;
|
||||
for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
|
||||
++j) {
|
||||
if (i < num_active_requests *
|
||||
sub_thread_pool_end_request_percentage_[j]) {
|
||||
sub_thread_pool_id = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
(*thread_work_sources)[i]->SetWaiter(
|
||||
&queue_waiters_[sub_thread_pool_id],
|
||||
&waiters_mu_[sub_thread_pool_id]);
|
||||
}
|
||||
version = ++version_;
|
||||
LogInfo();
|
||||
@ -759,6 +1000,9 @@ class RunHandlerPool::Impl {
|
||||
// inference).
|
||||
const int max_handlers_;
|
||||
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu_;
|
||||
Eigen::MaxSizeVector<Waiter> queue_waiters_;
|
||||
|
||||
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
|
||||
// Thread compatible part used only by lock under RunHandlerPool.
|
||||
// Handlers are sorted by start time.
|
||||
@ -773,6 +1017,7 @@ class RunHandlerPool::Impl {
|
||||
condition_variable one_handler_free_;
|
||||
mutex mu_;
|
||||
int64 version_ GUARDED_BY(mu_);
|
||||
const std::vector<double> sub_thread_pool_end_request_percentage_;
|
||||
};
|
||||
|
||||
void RunHandlerPool::Impl::RecomputePoolStats(
|
||||
|
@ -24,11 +24,17 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/barrier.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -72,5 +78,132 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
|
||||
counter.Wait();
|
||||
}
|
||||
|
||||
// Returns session options whose config requests two "CPU" devices.
SessionOptions DefaultSessionOptions() {
  SessionOptions options;
  auto& device_count = *options.config.mutable_device_count();
  device_count["CPU"] = 2;
  return options;
}
|
||||
|
||||
std::unique_ptr<Session> CreateSession() {
|
||||
return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
|
||||
}
|
||||
|
||||
class RunHandlerTest : public ::testing::Test {
|
||||
public:
|
||||
void Initialize(std::initializer_list<float> a_values) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&a_tensor, a_values);
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
|
||||
a_ = a->name();
|
||||
|
||||
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
|
||||
test::FillValues<float>(&x_tensor, {1, 1});
|
||||
Node* x = test::graph::Constant(&graph, x_tensor);
|
||||
x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
|
||||
x_ = x->name();
|
||||
|
||||
// y = A * x
|
||||
Node* y = test::graph::Matmul(&graph, a, x, false, false);
|
||||
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
|
||||
y_ = y->name();
|
||||
|
||||
Node* y_neg = test::graph::Unary(&graph, "Neg", y);
|
||||
y_neg_ = y_neg->name();
|
||||
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
|
||||
|
||||
Node* z = test::graph::Unary(&graph, "Identity", y_neg);
|
||||
z_ = z->name();
|
||||
z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
|
||||
|
||||
graph.ToGraphDef(&def_);
|
||||
|
||||
ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
|
||||
ASSERT_EQ(
|
||||
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
|
||||
0);
|
||||
ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
|
||||
"0,0.4", true),
|
||||
0);
|
||||
ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
||||
"0.4,1", true),
|
||||
0);
|
||||
ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
|
||||
}
|
||||
|
||||
string a_;
|
||||
string x_;
|
||||
string y_;
|
||||
string y_neg_;
|
||||
string z_;
|
||||
GraphDef def_;
|
||||
};
|
||||
|
||||
// Runs the fixture graph once through the run handler pool and checks the
// fetched value of y.
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct output.
  // Check initialization before reading the tensor contents, so that the
  // assertion actually guards the access below.
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
|
||||
|
||||
// Runs the fixture graph through the run handler pool from four closures
// scheduled on a thread pool and checks every result.
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));

  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  // Thread pool that issues the concurrent Run() calls. It is owned by a
  // unique_ptr so it cannot leak, and it is reset explicitly below.
  std::unique_ptr<thread::ThreadPool> tp(
      new thread::ThreadPool(Env::Default(), "test", 4));

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names, run_options]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(::tensorflow::Status::OK(), s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish: destroy the pool while `session`,
  // which the closures capture by reference, is still alive.
  tp.reset();
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -29,6 +30,54 @@ double ParamFromEnvWithDefault(const std::string& var_name,
|
||||
return (val && strings::safe_strtod(val, &num)) ? num : default_value;
|
||||
}
|
||||
|
||||
std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
|
||||
std::vector<double> default_value) {
|
||||
const char* val = std::getenv(var_name.c_str());
|
||||
if (!val) {
|
||||
return default_value;
|
||||
}
|
||||
std::vector<string> splits = str_util::Split(val, ",");
|
||||
std::vector<double> result;
|
||||
result.reserve(splits.size());
|
||||
for (auto& split : splits) {
|
||||
double num;
|
||||
if (strings::safe_strtod(split, &num)) {
|
||||
result.push_back(num);
|
||||
} else {
|
||||
LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
|
||||
std::vector<int> default_value) {
|
||||
const char* val = std::getenv(var_name.c_str());
|
||||
if (!val) {
|
||||
return default_value;
|
||||
}
|
||||
std::vector<string> splits = str_util::Split(val, ",");
|
||||
std::vector<int> result;
|
||||
result.reserve(splits.size());
|
||||
for (auto& split : splits) {
|
||||
int num;
|
||||
if (strings::safe_strto32(split, &num)) {
|
||||
result.push_back(num);
|
||||
} else {
|
||||
LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool ParamFromEnvBoolWithDefault(const std::string& var_name,
|
||||
bool default_value) {
|
||||
const char* val = std::getenv(var_name.c_str());
|
||||
return (val) ? str_util::Lowercase(val) == "true" : default_value;
|
||||
}
|
||||
|
||||
void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
|
||||
int min_threads_per_request,
|
||||
std::vector<std::uint_fast32_t>* start_vec,
|
||||
|
@ -54,10 +54,27 @@ void ComputeInterOpStealingRanges(int num_threads, int min_threads_per_domain,
|
||||
std::vector<int> ChooseRequestsWithExponentialDistribution(
|
||||
int num_active_requests, int num_threads);
|
||||
|
||||
// Loop environment variable named 'var_name' and return the value if it exist
|
||||
// and can be parsed. Return 'default_value' otherwise.
|
||||
// Look up environment variable named 'var_name' and return the value if it
|
||||
// exists and can be parsed. Return 'default_value' otherwise.
|
||||
double ParamFromEnvWithDefault(const std::string& var_name,
|
||||
double default_value);
|
||||
|
||||
// Look up environment variable named 'var_name' and return the value if it
|
||||
// exists and can be parsed. The value must be in format val1,val2... Return
|
||||
// 'default_value' otherwise.
|
||||
std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
|
||||
std::vector<double> default_value);
|
||||
|
||||
// Look up environment variable named 'var_name' and return the value if it
|
||||
// exists and can be parsed. The value must be in format val1,val2... Return
|
||||
// 'default_value' otherwise.
|
||||
std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
|
||||
std::vector<int> default_value);
|
||||
|
||||
// Look up environment variable named 'var_name' and return the value if it
|
||||
// exists and can be parsed. Return 'default_value' otherwise.
|
||||
bool ParamFromEnvBoolWithDefault(const std::string& var_name,
|
||||
bool default_value);
|
||||
|
||||
} // end namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
|
||||
|
@ -124,5 +124,44 @@ TEST(RunHandlerUtilTest, TestExponentialRequestDistribution) {
|
||||
ASSERT_EQ(actual_distribution, expected_distribution);
|
||||
}
|
||||
|
||||
// Exercises the vector<double>, vector<int> and bool environment-variable
// readers, first with the variables unset (defaults expected) and then after
// setting them via setenv (parsed values expected).
TEST(RunHandlerUtilTest, TestParamFromEnvWithDefault) {
  // The first half expects RUN_HANDLER_TEST_ENV and RUN_HANDLER_TEST_ENV_BOOL
  // to be unset, so every call must hand back the supplied default.
  std::vector<double> result = ParamFromEnvWithDefault(
      "RUN_HANDLER_TEST_ENV", std::vector<double>{0, 0, 0});
  EXPECT_EQ(result.size(), 3);
  EXPECT_EQ(result[0], 0);
  EXPECT_EQ(result[1], 0);
  EXPECT_EQ(result[2], 0);

  std::vector<int> result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                                     std::vector<int>{0, 0, 0});
  EXPECT_EQ(result2.size(), 3);
  EXPECT_EQ(result2[0], 0);
  EXPECT_EQ(result2[1], 0);
  EXPECT_EQ(result2[2], 0);

  bool result3 =
      ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
  EXPECT_EQ(result3, false);

  // Set environment variable.
  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV", "1,2,3", true), 0);
  result = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                   std::vector<double>{0, 0, 0});
  EXPECT_EQ(result.size(), 3);
  EXPECT_EQ(result[0], 1);
  EXPECT_EQ(result[1], 2);
  EXPECT_EQ(result[2], 3);
  result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                    std::vector<int>{0, 0, 0});
  // Check the size of the int vector that was just re-read (the previous
  // version re-checked the double vector here).
  EXPECT_EQ(result2.size(), 3);
  EXPECT_EQ(result2[0], 1);
  EXPECT_EQ(result2[1], 2);
  EXPECT_EQ(result2[2], 3);

  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV_BOOL", "true", true), 0);
  result3 = ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
  EXPECT_EQ(result3, true);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user