Refactor code for run handler thread pool, and add some unit tests.

PiperOrigin-RevId: 300835548
Change-Id: Ieb24ef54a327657af6452aaad893cb39fa02bd1c
Chao Xie 2020-03-13 15:12:41 -07:00 committed by TensorFlower Gardener
parent 8e80c097df
commit f4cd9ee784
3 changed files with 1057 additions and 492 deletions

tensorflow/core/framework/run_handler.cc

@@ -41,79 +41,52 @@ namespace {
static constexpr int32 kMaxConcurrentHandlers = 128;
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
// TODO(azaks): Refactor with thread:ThreadPool
class RunHandlerEnvironment {
typedef Thread EnvThread;
struct TaskImpl {
std::function<void()> f;
Context context;
uint64 trace_id;
};
Env* const env_;
const ThreadOptions thread_options_;
const string name_;
public:
struct Task {
std::unique_ptr<TaskImpl> f;
};
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
const string& name)
: env_(env), thread_options_(thread_options), name_(name) {}
EnvThread* CreateThread(std::function<void()> f) {
return env_->StartThread(thread_options_, name_, [=]() {
// Set the processor flag to flush denormals to zero.
port::ScopedFlushDenormal flush;
// Set the processor rounding mode to ROUND TO NEAREST.
port::ScopedSetRound round(FE_TONEAREST);
if (thread_options_.numa_node != port::kNUMANoAffinity) {
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
}
f();
});
}
Task CreateTask(std::function<void()> f) {
uint64 id = 0;
if (tracing::EventCollector::IsEnabled()) {
id = tracing::GetUniqueArg();
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
}
return Task{
std::unique_ptr<TaskImpl>(new TaskImpl{
std::move(f),
Context(ContextKind::kThread),
id,
}),
};
}
void ExecuteTask(const Task& t) {
WithContext wc(t.f->context);
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
t.f->trace_id);
t.f->f();
}
};
typedef typename RunHandlerEnvironment::Task Task;
typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
Waiter() {
next = this;
prev = this;
} // namespace
namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
Env* env, const ThreadOptions& thread_options, const string& name)
: env_(env), thread_options_(thread_options), name_(name) {}
RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
std::function<void()> f) {
return env_->StartThread(thread_options_, name_, [=]() {
// Set the processor flag to flush denormals to zero.
port::ScopedFlushDenormal flush;
// Set the processor rounding mode to ROUND TO NEAREST.
port::ScopedSetRound round(FE_TONEAREST);
if (thread_options_.numa_node != port::kNUMANoAffinity) {
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
}
f();
});
}
RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
std::function<void()> f) {
uint64 id = 0;
if (tracing::EventCollector::IsEnabled()) {
id = tracing::GetUniqueArg();
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
}
condition_variable cv;
mutex mu;
Waiter* next;
Waiter* prev;
};
return Task{
std::unique_ptr<TaskImpl>(new TaskImpl{
std::move(f),
Context(ContextKind::kThread),
id,
}),
};
}
void RunHandlerEnvironment::ExecuteTask(const Task& t) {
WithContext wc(t.f->context);
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
t.f->trace_id);
t.f->f();
}
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
int max_sleep_micros) {
@@ -150,442 +123,359 @@ void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
}
}
class ThreadWorkSource {
public:
ThreadWorkSource()
: non_blocking_work_sharding_factor_(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
non_blocking_work_queues_(non_blocking_work_sharding_factor_),
blocking_inflight_(0),
non_blocking_inflight_(0),
traceme_id_(0),
version_(0),
sub_thread_pool_waiter_(nullptr) {
queue_waiters_.next = &queue_waiters_;
queue_waiters_.prev = &queue_waiters_;
for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
ThreadWorkSource::ThreadWorkSource()
: non_blocking_work_sharding_factor_(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
non_blocking_work_queues_(non_blocking_work_sharding_factor_),
blocking_inflight_(0),
non_blocking_inflight_(0),
traceme_id_(0),
version_(0),
sub_thread_pool_waiter_(nullptr) {
queue_waiters_.next = &queue_waiters_;
queue_waiters_.prev = &queue_waiters_;
for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
}
}
ThreadWorkSource::~ThreadWorkSource() {
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
delete non_blocking_work_queues_[i];
}
}
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
mutex* mu = nullptr;
Queue* task_queue = nullptr;
thread_local int64 closure_counter = 0;
if (!is_blocking) {
int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
task_queue = &(non_blocking_work_queues_[queue_index]->queue);
mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
} else {
task_queue = &blocking_work_queue_;
mu = &blocking_queue_op_mu_;
}
{
mutex_lock l(*mu);
// For a given queue, only one thread can call PushFront.
t = task_queue->PushFront(std::move(t));
}
Waiter* w = nullptr;
bool use_sub_thread_pool =
ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
Waiter* waiter_queue;
mutex* waiter_queue_mu;
if (use_sub_thread_pool) {
// When we use multiple sub thread pools, free threads wait on the sub
// thread pool waiting queues, so wake up threads from the sub thread pool
// waiting queues.
// The waiting queues are defined in RunHandlerPool.
// Get the waiter_queue and corresponding mutex. Note that the thread work
// source may change afterwards if a new request comes in or an old request
// finishes.
tf_shared_lock lock(run_handler_waiter_mu_);
waiter_queue = sub_thread_pool_waiter_;
waiter_queue_mu = sub_thread_pool_waiter_mu_;
} else {
waiter_queue = &queue_waiters_;
waiter_queue_mu = &waiters_mu_;
}
{
mutex_lock l(*waiter_queue_mu);
if (waiter_queue->next != waiter_queue) {
// Remove waiter from the LIFO queue
w = waiter_queue->next;
CHECK(w->prev != w); // Crash OK.
CHECK(w->next != w); // Crash OK.
w->next->prev = w->prev;
w->prev->next = w->next;
// Use `w->next == &w` to indicate that the waiter has been removed
// from the queue.
w->next = w;
w->prev = w;
}
}
if (w != nullptr) {
// We call notify_one() without any locks, so we can miss notifications.
// The wake-up logic is best effort and a thread will wake up within a short
// period of time in case a notification is missed.
w->cv.notify_one();
}
VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
<< traceme_id_.load(std::memory_order_relaxed);
return t;
}
Task ThreadWorkSource::PopBlockingTask() {
return blocking_work_queue_.PopBack();
}
Task ThreadWorkSource::PopNonBlockingTask(int start_index,
bool search_from_all_queue) {
Task t;
unsigned sharding_factor = NonBlockingWorkShardingFactor();
for (unsigned j = 0; j < sharding_factor; ++j) {
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
->queue.PopBack();
if (t.f) {
return t;
}
if (!search_from_all_queue) {
break;
}
}
return t;
}
void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
thread_local Waiter waiter;
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}
int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
if (is_blocking) {
return blocking_work_queue_.Size();
} else {
unsigned total_size = 0;
for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
total_size += non_blocking_work_queues_[i]->queue.Size();
}
return total_size;
}
}
int64 ThreadWorkSource::GetTracemeId() {
return traceme_id_.load(std::memory_order_relaxed);
}
void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
{
tf_shared_lock lock(run_handler_waiter_mu_);
// Most requests won't change the sub pool on recomputation.
// As an optimization, avoid holding the exclusive lock to reduce contention.
if (sub_thread_pool_waiter_ == waiter) {
return;
}
// If the current version is a newer version, no need to update.
if (version_ > version) {
return;
}
}
~ThreadWorkSource() {
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
delete non_blocking_work_queues_[i];
}
}
mutex_lock l(run_handler_waiter_mu_);
sub_thread_pool_waiter_ = waiter;
sub_thread_pool_waiter_mu_ = mutex;
version_ = version;
}
Task EnqueueTask(Task t, bool is_blocking) {
mutex* mu = nullptr;
Queue* task_queue = nullptr;
thread_local int64 closure_counter = 0;
int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
return counter->load(std::memory_order_relaxed);
}
if (!is_blocking) {
int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
task_queue = &(non_blocking_work_queues_[queue_index]->queue);
mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
} else {
task_queue = &blocking_work_queue_;
mu = &blocking_queue_op_mu_;
}
void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_add(1, std::memory_order_relaxed);
}
void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_sub(1, std::memory_order_relaxed);
}
unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
return non_blocking_work_sharding_factor_;
}
std::string ThreadWorkSource::ToString() {
return strings::StrCat("traceme_id = ", GetTracemeId(),
", inter queue size = ", TaskQueueSize(true),
", inter inflight = ", GetInflightTaskCount(true),
", intra queue size = ", TaskQueueSize(false),
", intra inflight = ", GetInflightTaskCount(false));
}
RunHandlerThreadPool::RunHandlerThreadPool(
int num_blocking_threads, int num_non_blocking_threads, Env* env,
const ThreadOptions& thread_options, const string& name,
Eigen::MaxSizeVector<mutex>* waiters_mu,
Eigen::MaxSizeVector<Waiter>* queue_waiters)
: num_threads_(num_blocking_threads + num_non_blocking_threads),
num_blocking_threads_(num_blocking_threads),
num_non_blocking_threads_(num_non_blocking_threads),
thread_data_(num_threads_),
env_(env, thread_options, name),
name_(name),
waiters_mu_(waiters_mu),
queue_waiters_(queue_waiters),
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
std::vector<int>({num_blocking_threads / 2,
num_blocking_threads - num_blocking_threads / 2}))),
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
std::vector<double>({0, 0.4}))),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({0.4, 1}))) {
thread_data_.resize(num_threads_);
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
<< num_blocking_threads_ << " blocking threads and "
<< num_non_blocking_threads_ << " non-blocking threads.";
}
RunHandlerThreadPool::~RunHandlerThreadPool() {
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
cancelled_ = true;
for (size_t i = 0; i < thread_data_.size(); ++i) {
{
mutex_lock l(*mu);
// For a given queue, only one thread can call PushFront.
t = task_queue->PushFront(std::move(t));
mutex_lock l(thread_data_[i].mu);
thread_data_[i].sources_not_empty.notify_all();
}
Waiter* w = nullptr;
bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
Waiter* waiter_queue;
mutex* waiter_queue_mu;
if (use_sub_thread_pool) {
// When we use multiple sub thread pools, free threads wait on the sub
// thread pool waiting queues, so wake up threads from the sub thread pool
// waiting queues.
// The waiting queues are defined in RunHandlerPool.
// Get the waiter_queue and corresponding mutex. Note that the thread work
// source may change afterwards if a new request comes in or an old request
// finishes.
tf_shared_lock lock(run_handler_waiter_mu_);
waiter_queue = sub_thread_pool_waiter_;
waiter_queue_mu = sub_thread_pool_waiter_mu_;
} else {
waiter_queue = &queue_waiters_;
waiter_queue_mu = &waiters_mu_;
}
{
mutex_lock l(*waiter_queue_mu);
if (waiter_queue->next != waiter_queue) {
// Remove waiter from the LIFO queue
w = waiter_queue->next;
CHECK(w->prev != w);
CHECK(w->next != w);
w->next->prev = w->prev;
w->prev->next = w->next;
// Use `w->next == &w` to indicate that the waiter has been removed
// from the queue.
w->next = w;
w->prev = w;
}
}
if (w != nullptr) {
// We call notify_one() without any locks, so we can miss notifications.
// The wake-up logic is best effort and a thread will wake up within a short
// period of time in case a notification is missed.
w->cv.notify_one();
}
VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
<< traceme_id_.load(std::memory_order_relaxed);
return t;
thread_data_[i].thread.reset();
}
}
Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
Task t;
unsigned sharding_factor = NonBlockingWorkShardingFactor();
for (unsigned j = 0; j < sharding_factor; ++j) {
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
->queue.PopBack();
if (t.f) {
return t;
}
if (!search_from_all_queue) {
void RunHandlerThreadPool::Start() {
cancelled_ = false;
int num_blocking_threads = num_blocking_threads_;
for (int i = 0; i < num_threads_; i++) {
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
if (i < num_threads_in_sub_thread_pool_[j]) {
sub_thread_pool_id = j;
break;
}
}
return t;
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
thread_data_[i].thread.reset(
env_.CreateThread([this, i, num_blocking_threads]() {
WorkerLoop(i, i < num_blocking_threads);
}));
}
}
void WaitForWork(int max_sleep_micros) {
thread_local Waiter waiter;
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
void RunHandlerThreadPool::StartOneThreadForTesting() {
cancelled_ = false;
thread_data_[0].sub_thread_pool_id = 0;
thread_data_[0].thread.reset(
env_.CreateThread([this]() { WorkerLoop(0, true); }));
}
void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
bool is_blocking,
std::function<void()> fn) {
Task t = env_.CreateTask(std::move(fn));
t = tws->EnqueueTask(std::move(t), is_blocking);
if (t.f) {
VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
<< tws->GetTracemeId();
env_.ExecuteTask(t);
}
}
int TaskQueueSize(bool is_blocking) {
if (is_blocking) {
return blocking_work_queue_.Size();
} else {
unsigned total_size = 0;
for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
total_size += non_blocking_work_queues_[i]->queue.Size();
}
return total_size;
}
// TODO(donglin) Change the task steal order to be round-robin such that if
// an attempt to steal task from request i failed, then attempt to steal task
// from the next request in terms of the arrival time. This approach may
// provide better performance due to less lock retention. The drawback is that
// the profiler will be a bit harder to read.
void RunHandlerThreadPool::SetThreadWorkSources(
int tid, int start_request_idx, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
mutex_lock l(thread_data_[tid].mu);
if (version > thread_data_[tid].new_version) {
thread_data_[tid].new_version = version;
} else {
// A newer version is already updated. No need to update.
return;
}
int64 GetTracemeId() { return traceme_id_.load(std::memory_order_relaxed); }
void SetTracemeId(int64 value) { traceme_id_ = value; }
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
{
tf_shared_lock lock(run_handler_waiter_mu_);
// Most requests won't change the sub pool on recomputation.
// As an optimization, avoid holding the exclusive lock to reduce contention.
if (sub_thread_pool_waiter_ == waiter) {
return;
}
// If the current version is a newer version, no need to update.
if (version_ > version) {
return;
}
}
mutex_lock l(run_handler_waiter_mu_);
sub_thread_pool_waiter_ = waiter;
sub_thread_pool_waiter_mu_ = mutex;
version_ = version;
}
int64 GetInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
return counter->load(std::memory_order_relaxed);
}
void IncrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_add(1, std::memory_order_relaxed);
}
void DecrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_sub(1, std::memory_order_relaxed);
}
unsigned NonBlockingWorkShardingFactor() {
return non_blocking_work_sharding_factor_;
}
std::string ToString() {
return strings::StrCat("traceme_id = ", GetTracemeId(),
", inter queue size = ", TaskQueueSize(true),
", inter inflight = ", GetInflightTaskCount(true),
", intra queue size = ", TaskQueueSize(false),
", intra inflight = ", GetInflightTaskCount(false));
}
private:
struct NonBlockingQueue {
mutex queue_op_mu;
char pad[128];
Queue queue;
};
int32 non_blocking_work_sharding_factor_;
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
std::atomic<int64> blocking_inflight_;
std::atomic<int64> non_blocking_inflight_;
Queue blocking_work_queue_;
mutex blocking_queue_op_mu_;
char pad_[128];
mutex waiters_mu_;
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
std::atomic<int64> traceme_id_;
mutex run_handler_waiter_mu_;
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
class RunHandlerThreadPool {
public:
struct PerThread {
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
};
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
Env* env, const ThreadOptions& thread_options,
const string& name,
Eigen::MaxSizeVector<mutex>* waiters_mu,
Eigen::MaxSizeVector<Waiter>* queue_waiters)
: num_threads_(num_blocking_threads + num_non_blocking_threads),
num_blocking_threads_(num_blocking_threads),
num_non_blocking_threads_(num_non_blocking_threads),
thread_data_(num_threads_),
env_(env, thread_options, name),
name_(name),
waiters_mu_(waiters_mu),
queue_waiters_(queue_waiters),
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
std::vector<int>(
{num_blocking_threads / 2,
num_blocking_threads - num_blocking_threads / 2}))),
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
std::vector<double>({0, 0.4}))),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({0.4, 1}))) {
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
<< num_blocking_threads_ << " blocking threads and "
<< num_non_blocking_threads_ << " non-blocking threads.";
}
~RunHandlerThreadPool() {
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
cancelled_ = true;
for (size_t i = 0; i < thread_data_.size(); ++i) {
{
mutex_lock l(thread_data_[i].mu);
thread_data_[i].sources_not_empty.notify_all();
}
thread_data_[i].thread.reset();
}
}
void Start() {
cancelled_ = false;
thread_data_.resize(num_threads_);
int num_blocking_threads = num_blocking_threads_;
for (int i = 0; i < num_threads_; i++) {
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
if (i < num_threads_in_sub_thread_pool_[j]) {
sub_thread_pool_id = j;
break;
}
}
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
thread_data_[i].thread.reset(
env_.CreateThread([this, i, num_blocking_threads]() {
WorkerLoop(i, i < num_blocking_threads);
}));
}
}
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
std::function<void()> fn) {
Task t = env_.CreateTask(std::move(fn));
t = tws->EnqueueTask(std::move(t), is_blocking);
if (t.f) {
VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
<< tws->GetTracemeId();
env_.ExecuteTask(t);
}
}
// Set work queues from which the thread 'tid' can steal its work.
// The request with start_request_idx will be attempted first. Other requests
// will be attempted in FIFO order based on their arrival time.
// TODO(donglin) Change the task steal order to be round-robin such that if
// an attempt to steal task from request i failed, then attempt to steal task
// from the next request in terms of the arrival time. This approach may
// provide better performance due to less lock retention. The drawback is that
// the profiler will be a bit harder to read.
void SetThreadWorkSources(
int tid, int start_request_idx, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
mutex_lock l(thread_data_[tid].mu);
if (version > thread_data_[tid].new_version) {
thread_data_[tid].new_version = version;
} else {
// A newer version is already updated. No need to update.
return;
}
thread_data_[tid].new_thread_work_sources->resize(0);
if (use_sub_thread_pool_) {
for (int i = 0; i < thread_work_sources.size(); ++i) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[i]);
}
} else {
thread_data_[tid].new_thread_work_sources->resize(0);
if (use_sub_thread_pool_) {
for (int i = 0; i < thread_work_sources.size(); ++i) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[start_request_idx]);
// The number of shards for the queue. Threads in each shard will
// prioritize different thread_work_sources. Increasing the number of shards
// can decrease the contention in the queue. For example, when
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
// thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
// 4... for the other half of the threads.
int num_shards =
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
int token = tid % num_shards;
for (int i = 0; i < num_shards; ++i) {
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
if (j != start_request_idx) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[j]);
}
thread_work_sources[i]);
}
} else {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[start_request_idx]);
// The number of shards for the queue. Threads in each shard will
// prioritize different thread_work_sources. Increasing the number of shards
// can decrease the contention in the queue. For example, when
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
// thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
// 4... for the other half of the threads.
int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
int token = tid % num_shards;
for (int i = 0; i < num_shards; ++i) {
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
if (j != start_request_idx) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[j]);
}
token = (token + 1) % num_shards;
}
thread_data_[tid].sources_not_empty.notify_all();
token = (token + 1) % num_shards;
}
thread_data_[tid].sources_not_empty.notify_all();
}
}
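The shard-based scan described in the comment inside SetThreadWorkSources above can be made concrete with a small standalone sketch. The function and variable names below are illustrative only and are not part of the commit; it simply reproduces the documented ordering: the thread's own request first, then the remaining requests visited shard by shard.

#include <vector>

// Illustrative only: build the order in which a worker thread scans
// thread_work_sources when sub thread pools are disabled. The thread's own
// request (start_request_idx) comes first, then the remaining requests are
// visited shard by shard, with the starting shard derived from the thread id.
std::vector<int> DemoSourceScanOrder(int tid, int num_shards,
                                     int start_request_idx, int num_sources) {
  std::vector<int> order;
  order.push_back(start_request_idx);
  int token = tid % num_shards;
  for (int i = 0; i < num_shards; ++i) {
    for (int j = token; j < num_sources; j += num_shards) {
      if (j != start_request_idx) order.push_back(j);
    }
    token = (token + 1) % num_shards;
  }
  return order;
}

// With num_shards == 2 and 6 requests, even threads scan
// start_request_idx, 0, 2, 4, 1, 3, 5 and odd threads scan
// start_request_idx, 1, 3, 5, 0, 2, 4, matching the comment above.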
PerThread* GetPerThread() {
thread_local PerThread per_thread_;
PerThread* pt = &per_thread_;
return pt;
RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
thread_local RunHandlerThreadPool::PerThread per_thread_;
RunHandlerThreadPool::PerThread* pt = &per_thread_;
return pt;
}
int RunHandlerThreadPool::CurrentThreadId() const {
const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
return -1;
}
}
int CurrentThreadId() const {
const PerThread* pt =
const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
return -1;
}
}
int RunHandlerThreadPool::NumThreads() const { return num_threads_; }
int NumThreads() const { return num_threads_; }
int RunHandlerThreadPool::NumBlockingThreads() const {
return num_blocking_threads_;
}
int NumBlockingThreads() const { return num_blocking_threads_; }
int RunHandlerThreadPool::NumNonBlockingThreads() const {
return num_non_blocking_threads_;
}
int NumNonBlockingThreads() const { return num_non_blocking_threads_; }
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
// Search for tasks in the request range searching_range_start to
// searching_range_end. If there are no tasks in the search range and
// may_steal_blocking_work is true, then search from all requests.
Task FindTask(
int searching_range_start, int searching_range_end, int thread_id,
int sub_thread_pool_id, int max_blocking_inflight,
bool may_steal_blocking_work,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
bool* task_from_blocking_queue, ThreadWorkSource** tws);
void WaitForWork(bool is_blocking, int thread_id,
int32 max_blocking_inflight);
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
private:
struct ThreadData {
ThreadData()
: new_version(0),
current_index(0),
new_thread_work_sources(new Eigen::MaxSizeVector<ThreadWorkSource*>(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))),
current_version(0),
current_thread_work_sources(
new Eigen::MaxSizeVector<ThreadWorkSource*>(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))) {}
mutex mu;
uint64 new_version;
condition_variable sources_not_empty;
std::unique_ptr<Thread> thread;
int current_index;
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
new_thread_work_sources TF_GUARDED_BY(mu);
uint64 current_version;
// Should only be accessed by one thread.
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
current_thread_work_sources;
int sub_thread_pool_id;
};
const int num_threads_;
const int num_blocking_threads_;
const int num_non_blocking_threads_;
Eigen::MaxSizeVector<ThreadData> thread_data_;
RunHandlerEnvironment env_;
std::atomic<bool> cancelled_;
string name_;
Eigen::MaxSizeVector<mutex>* waiters_mu_;
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
bool use_sub_thread_pool_;
std::vector<int> num_threads_in_sub_thread_pool_;
// Threads in each sub thread pool will search tasks from the given
// start_request_percentage to end_request_percentage in a round robin
// fashion.
std::vector<double> sub_thread_pool_start_request_percentage_;
std::vector<double> sub_thread_pool_end_request_percentage_;
};
RunHandlerThreadPool::ThreadData::ThreadData()
: new_version(0),
current_index(0),
new_thread_work_sources(
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))),
current_version(0),
current_thread_work_sources(
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))) {}
Task RunHandlerThreadPool::FindTask(
int searching_range_start, int searching_range_end, int thread_id,
@@ -597,10 +487,9 @@ Task RunHandlerThreadPool::FindTask(
int current_index = thread_data_[thread_id].current_index;
*task_from_blocking_queue = false;
// TODO(chaox): Change the search algorithm from round robin to random
// walk.
for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
if (current_index >= searching_range_end) {
if (current_index >= searching_range_end ||
current_index < searching_range_start) {
current_index = searching_range_start;
}
*tws = thread_work_sources[current_index];
@@ -821,7 +710,7 @@ void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
tws->WaitForWork(kMaxSleepMicros);
}
} // namespace
} // namespace internal
// Contains the concrete implementation of the RunHandler.
// Externally visible RunHandler class simply forwards the work to this one.
@@ -847,7 +736,7 @@ class RunHandler::Impl {
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
ThreadWorkSource* tws() { return &tws_; }
internal::ThreadWorkSource* tws() { return &tws_; }
int64 priority() { return options_.priority(); }
@@ -869,7 +758,7 @@ class RunHandler::Impl {
uint64 start_time_us_;
int64 step_id_;
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
ThreadWorkSource tws_;
internal::ThreadWorkSource tws_;
RunOptions::Experimental::RunHandlerPoolOptions options_;
};
@@ -885,7 +774,7 @@ class RunHandlerPool::Impl {
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
queue_waiters_(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
run_handler_thread_pool_(new RunHandlerThreadPool(
run_handler_thread_pool_(new internal::RunHandlerThreadPool(
num_inter_op_threads, num_intra_op_threads, Env::Default(),
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
&queue_waiters_)),
@@ -924,7 +813,7 @@ class RunHandlerPool::Impl {
run_handler_thread_pool_.reset();
}
RunHandlerThreadPool* run_handler_thread_pool() {
internal::RunHandlerThreadPool* run_handler_thread_pool() {
return run_handler_thread_pool_.get();
}
@@ -936,10 +825,11 @@ class RunHandlerPool::Impl {
int64 step_id, int64 timeout_in_ms,
const RunOptions::Experimental::RunHandlerPoolOptions& options)
TF_LOCKS_EXCLUDED(mu_) {
thread_local std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
thread_local std::unique_ptr<
Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
thread_work_sources =
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>(
new Eigen::MaxSizeVector<ThreadWorkSource*>(
std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers))));
@@ -1035,7 +925,8 @@ class RunHandlerPool::Impl {
private:
void RecomputePoolStats(
int num_active_requests, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
thread_work_sources);
void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
@@ -1046,9 +937,9 @@ class RunHandlerPool::Impl {
const int max_handlers_;
Eigen::MaxSizeVector<mutex> waiters_mu_;
Eigen::MaxSizeVector<Waiter> queue_waiters_;
Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
// Thread compatible part used only by lock under RunHandlerPool.
// Handlers are sorted by start time.
// TODO(azaks): sort by the remaining latency budget.
@@ -1070,7 +961,8 @@ class RunHandlerPool::Impl {
void RunHandlerPool::Impl::RecomputePoolStats(
int num_active_requests, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
thread_work_sources) {
if (num_active_requests == 0) return;
int sub_thread_pool_id = 0;

tensorflow/core/framework/run_handler.h

@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
@@ -106,6 +107,208 @@ class RunHandler {
Impl* impl_; // NOT OWNED.
};
namespace internal {
// TODO(azaks): Refactor with thread:ThreadPool
class RunHandlerEnvironment {
typedef Thread EnvThread;
struct TaskImpl {
std::function<void()> f;
Context context;
uint64 trace_id;
};
Env* const env_;
const ThreadOptions thread_options_;
const string name_;
public:
struct Task {
std::unique_ptr<TaskImpl> f;
};
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
const string& name);
EnvThread* CreateThread(std::function<void()> f);
Task CreateTask(std::function<void()> f);
void ExecuteTask(const Task& t);
};
typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
Waiter() {
next = this;
prev = this;
}
condition_variable cv;
mutex mu;
Waiter* next;
Waiter* prev;
};
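The Waiter struct above is the mechanism the comment describes: instead of a single shared condition variable, each idle thread parks on its own condition variable and is queued LIFO, so a producer wakes exactly the most recently idled thread. The following is a minimal standalone sketch of that pattern built on standard-library primitives; all names are illustrative and it is not TensorFlow's implementation (which uses its own mutex/condition_variable types and additional bookkeeping).

#include <chrono>
#include <condition_variable>
#include <mutex>

struct DemoWaiter {
  DemoWaiter() : next(this), prev(this), notified(false) {}
  std::condition_variable cv;
  std::mutex mu;
  DemoWaiter* next;   // Links are guarded by the queue mutex.
  DemoWaiter* prev;
  bool notified;      // Guarded by `mu`.
};

// Queues `w` right after the dummy head (LIFO order) and sleeps until a waker
// notifies it or `max_sleep_micros` elapses.
void DemoWaitOnWaiter(DemoWaiter* w, DemoWaiter* queue_head,
                      std::mutex* queue_mu, int max_sleep_micros) {
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    w->next = queue_head->next;
    w->prev = queue_head;
    w->next->prev = w;
    w->prev->next = w;
  }
  {
    std::unique_lock<std::mutex> l(w->mu);
    w->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros),
                   [w] { return w->notified; });
    w->notified = false;
  }
  // If the timeout fired before a waker unlinked us, unlink ourselves.
  std::lock_guard<std::mutex> l(*queue_mu);
  if (w->next != w) {
    w->next->prev = w->prev;
    w->prev->next = w->next;
    w->next = w;
    w->prev = w;
  }
}

// Wakes the most recently queued waiter, if any; only that one thread wakes,
// which is the point of the per-waiter condition variables.
void DemoWakeOneWaiter(DemoWaiter* queue_head, std::mutex* queue_mu) {
  DemoWaiter* w = nullptr;
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    if (queue_head->next != queue_head) {
      w = queue_head->next;  // LIFO: pop from the head.
      w->next->prev = w->prev;
      w->prev->next = w->next;
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    std::lock_guard<std::mutex> l(w->mu);
    w->notified = true;
    w->cv.notify_one();
  }
}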
class ThreadWorkSource {
public:
ThreadWorkSource();
~ThreadWorkSource();
Task EnqueueTask(Task t, bool is_blocking);
Task PopBlockingTask();
Task PopNonBlockingTask(int start_index, bool search_from_all_queue);
void WaitForWork(int max_sleep_micros);
int TaskQueueSize(bool is_blocking);
int64 GetTracemeId();
void SetTracemeId(int64 value);
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);
int64 GetInflightTaskCount(bool is_blocking);
void IncrementInflightTaskCount(bool is_blocking);
void DecrementInflightTaskCount(bool is_blocking);
unsigned NonBlockingWorkShardingFactor();
std::string ToString();
private:
struct NonBlockingQueue {
mutex queue_op_mu;
char pad[128];
Queue queue;
};
int32 non_blocking_work_sharding_factor_;
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
std::atomic<int64> blocking_inflight_;
std::atomic<int64> non_blocking_inflight_;
Queue blocking_work_queue_;
mutex blocking_queue_op_mu_;
char pad_[128];
mutex waiters_mu_;
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
std::atomic<int64> traceme_id_;
mutex run_handler_waiter_mu_;
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
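ThreadWorkSource keeps a single blocking queue but shards its non-blocking work across TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES queues; EnqueueTask spreads each thread's intra-op closures over the shards with a thread_local counter. A tiny illustrative sketch of that index selection (the function name is made up for the example):

#include <cstdint>

// Illustrative only: mirror EnqueueTask's thread_local counter that spreads a
// thread's non-blocking closures across the sharded queues round robin.
int DemoPickNonBlockingShard(int sharding_factor) {
  thread_local std::int64_t closure_counter = 0;
  return static_cast<int>(++closure_counter % sharding_factor);
}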
class RunHandlerThreadPool {
public:
struct PerThread {
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
};
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
Env* env, const ThreadOptions& thread_options,
const string& name,
Eigen::MaxSizeVector<mutex>* waiters_mu,
Eigen::MaxSizeVector<Waiter>* queue_waiters);
~RunHandlerThreadPool();
void Start();
void StartOneThreadForTesting();
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
std::function<void()> fn);
// Set work queues from which the thread 'tid' can steal its work.
// The request with start_request_idx will be attempted first. Other requests
// will be attempted in FIFO order based on their arrival time.
void SetThreadWorkSources(
int tid, int start_request_idx, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
PerThread* GetPerThread();
int CurrentThreadId() const;
int NumThreads() const;
int NumBlockingThreads() const;
int NumNonBlockingThreads() const;
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
// Search for tasks in the request range searching_range_start to
// searching_range_end. If there are no tasks in the search range and
// may_steal_blocking_work is true, then search from all requests.
Task FindTask(
int searching_range_start, int searching_range_end, int thread_id,
int sub_thread_pool_id, int max_blocking_inflight,
bool may_steal_blocking_work,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
bool* task_from_blocking_queue, ThreadWorkSource** tws);
void WaitForWork(bool is_blocking, int thread_id,
int32 max_blocking_inflight);
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
private:
struct ThreadData {
ThreadData();
mutex mu;
uint64 new_version;
condition_variable sources_not_empty;
std::unique_ptr<Thread> thread;
int current_index;
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
new_thread_work_sources TF_GUARDED_BY(mu);
uint64 current_version;
// Should only be accessed by one thread.
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
current_thread_work_sources;
int sub_thread_pool_id;
};
const int num_threads_;
const int num_blocking_threads_;
const int num_non_blocking_threads_;
Eigen::MaxSizeVector<ThreadData> thread_data_;
internal::RunHandlerEnvironment env_;
std::atomic<bool> cancelled_;
string name_;
Eigen::MaxSizeVector<mutex>* waiters_mu_;
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
bool use_sub_thread_pool_;
std::vector<int> num_threads_in_sub_thread_pool_;
// Threads in each sub thread pool will search tasks from the given
// start_request_percentage to end_request_percentage in a round robin
// fashion.
std::vector<double> sub_thread_pool_start_request_percentage_;
std::vector<double> sub_thread_pool_end_request_percentage_;
};
} // namespace internal
} // end namespace tensorflow.
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
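RunHandlerThreadPool above derives each sub thread pool's search range from sub_thread_pool_start_request_percentage_ and sub_thread_pool_end_request_percentage_. The sketch below shows the rough mapping from percentages to request indices; the exact rounding used by the implementation is an assumption here, and the names are illustrative only.

#include <utility>
#include <vector>

// Illustrative only: map a sub thread pool's percentage bounds onto a
// half-open [start, end) range of active request indices. Truncation to int
// is an assumption made for this example.
std::pair<int, int> DemoRequestSearchRange(
    const std::vector<double>& start_request_percentage,
    const std::vector<double>& end_request_percentage,
    int sub_thread_pool_id, int num_active_requests) {
  int start = static_cast<int>(
      num_active_requests * start_request_percentage[sub_thread_pool_id]);
  int end = static_cast<int>(
      num_active_requests * end_request_percentage[sub_thread_pool_id]);
  return {start, end};
}

// With the defaults {0, 0.4} and {0.4, 1} and 10 active requests, sub pool 0
// searches requests [0, 4) and sub pool 1 searches requests [4, 10).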

tensorflow/core/framework/run_handler_test.cc

@@ -113,6 +113,476 @@ TEST(RunHandlerUtilTest, PrioritySchedulingTest) {
EXPECT_EQ(sorted_active_list[3], 1);
}
TEST(RunHandlerThreadPool, EnqueueTask) {
Eigen::MaxSizeVector<mutex> waiters_mu(2);
waiters_mu.resize(2);
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
waiters.resize(2);
internal::RunHandlerThreadPool run_handler_thread_pool(
/*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0,
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
&waiters);
internal::ThreadWorkSource tws;
int result = 0;
std::function<void()> fn = [&result] { result = 1; };
std::function<void()> fn2 = [&result] { result = 2; };
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn);
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 1);
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn2);
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 2);
tws.PopBlockingTask().f->f();
EXPECT_EQ(result, 1);
tws.PopBlockingTask().f->f();
EXPECT_EQ(result, 2);
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn);
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 1);
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn2);
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 2);
tws.PopNonBlockingTask(0, true).f->f();
EXPECT_EQ(result, 1);
tws.PopNonBlockingTask(0, true).f->f();
EXPECT_EQ(result, 2);
}
TEST(RunHandlerThreadPool, FindTask) {
Eigen::MaxSizeVector<mutex> waiters_mu(2);
waiters_mu.resize(2);
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
waiters.resize(2);
internal::RunHandlerThreadPool run_handler_thread_pool(
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
&waiters);
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(5);
thread_work_sources.resize(5);
for (int i = 0; i < 5; ++i) {
thread_work_sources[i] = new internal::ThreadWorkSource();
}
{
// The thread should search for tasks in a round robin fashion.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
/*is_blocking=*/true,
[&result] { result = 3; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
/*is_blocking=*/true,
[&result] { result = 3; });
const auto find_blocking_task_from_all_handlers =
[&](bool* task_from_blocking_queue, internal::Task* t) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
/*searching_range_start=*/0, /*searching_range_end=*/5,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
/*may_steal_blocking_work=*/true, thread_work_sources,
task_from_blocking_queue, &tws);
};
bool task_from_blocking_queue;
internal::Task t;
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 3);
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 3);
}
{
// Tasks outside the searching range cannot be found.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
/*is_blocking=*/true,
[&result] { result = 3; });
const auto find_blocking_task_from_range =
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
int range_end) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
range_start, range_end,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
/*may_steal_blocking_work=*/true, thread_work_sources,
task_from_blocking_queue, &tws);
};
bool task_from_blocking_queue;
internal::Task t;
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
EXPECT_EQ(t.f, nullptr);
// Clean up the queue.
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
}
{
// The thread should search from the range start if the current index is
// smaller.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
/*is_blocking=*/true,
[&result] { result = 3; });
const auto find_blocking_task_from_range =
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
int range_end) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
range_start, range_end,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
/*may_steal_blocking_work=*/true, thread_work_sources,
task_from_blocking_queue, &tws);
};
bool task_from_blocking_queue;
internal::Task t;
find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 3);
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
}
{
// The thread should search within the range even if the current index
// is larger than searching_range_end.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
const auto find_blocking_task_from_range =
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
int range_end) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
range_start, range_end,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
/*may_steal_blocking_work=*/true, thread_work_sources,
task_from_blocking_queue, &tws);
};
bool task_from_blocking_queue;
// Make the current index 3.
internal::Task t;
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
// Search in a smaller range.
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
/*is_blocking=*/true,
[&result] { result = 3; });
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
// Clean up the queue.
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 3);
}
{
// We prefer blocking tasks for blocking threads.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/false,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
const auto blocking_thread_find_task_from_all_handler =
[&](bool* task_from_blocking_queue, internal::Task* t) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
/*searching_range_start=*/0, /*searching_range_end=*/5,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
/*may_steal_blocking_work=*/true, thread_work_sources,
task_from_blocking_queue, &tws);
};
bool task_from_blocking_queue;
internal::Task t;
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, true);
t.f->f();
EXPECT_EQ(result, 2);
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
EXPECT_EQ(task_from_blocking_queue, false);
t.f->f();
EXPECT_EQ(result, 2);
}
{
// Non-blocking threads can only pick up non-blocking tasks.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/false,
[&result] { result = 2; });
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
internal::Task* t,
bool is_blocking_thread) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
/*searching_range_start=*/0, /*searching_range_end=*/5,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
&tws);
};
bool task_from_blocking_queue;
internal::Task t;
find_task_from_all_handler(&task_from_blocking_queue, &t,
/*is_blocking_thread=*/false);
EXPECT_EQ(task_from_blocking_queue, false);
t.f->f();
EXPECT_EQ(result, 2);
find_task_from_all_handler(&task_from_blocking_queue, &t,
/*is_blocking_thread=*/false);
EXPECT_EQ(t.f, nullptr);
// Clean up the queue.
find_task_from_all_handler(&task_from_blocking_queue, &t,
/*is_blocking_thread=*/true);
}
{
// There is a limit for max_blocking_inflight requests.
int result = -1;
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
/*is_blocking=*/true,
[&result] { result = 2; });
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
internal::Task* t,
bool is_blocking_thread) {
internal::ThreadWorkSource* tws;
*t = run_handler_thread_pool.FindTask(
/*searching_range_start=*/0, /*searching_range_end=*/5,
/*thread_id=*/0,
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
&tws);
};
bool task_from_blocking_queue;
internal::Task t;
find_task_from_all_handler(&task_from_blocking_queue, &t,
/*is_blocking_thread=*/false);
EXPECT_EQ(task_from_blocking_queue, false);
EXPECT_EQ(t.f, nullptr);
// Clean up the queue.
find_task_from_all_handler(&task_from_blocking_queue, &t,
/*is_blocking_thread=*/true);
}
for (int i = 0; i < 5; ++i) {
delete thread_work_sources[i];
}
}
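The FindTask cases above exercise the round-robin cursor fix shown in the run_handler.cc hunk: the per-thread current_index is reset to the range start whenever it falls outside [searching_range_start, searching_range_end), including when it sits below the new range start. A one-function sketch of that clamp (illustrative name only):

// Illustrative only: clamp the per-thread round-robin cursor back into the
// searching range before it is used to pick a ThreadWorkSource.
int DemoClampSearchCursor(int current_index, int searching_range_start,
                          int searching_range_end) {
  if (current_index >= searching_range_end ||
      current_index < searching_range_start) {
    current_index = searching_range_start;
  }
  return current_index;
}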
TEST(RunHandlerThreadPool, RoundRobinExecution) {
// Set up environment for 1 sub thread pool.
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true);
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true);
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true);
Eigen::MaxSizeVector<mutex> waiters_mu(1);
waiters_mu.resize(1);
Eigen::MaxSizeVector<internal::Waiter> waiters(1);
waiters.resize(1);
internal::RunHandlerThreadPool* run_handler_thread_pool =
new internal::RunHandlerThreadPool(
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
&waiters);
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(3);
thread_work_sources.resize(3);
internal::ThreadWorkSource tws[3];
for (int i = 0; i < 3; ++i) {
tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]);
thread_work_sources[i] = &tws[i];
}
int result = 0;
mutex mu;
bool ok_to_execute = false;
bool ok_to_validate = false;
condition_variable function_start;
condition_variable function_end;
std::vector<std::function<void()>> fns;
for (int i = 0; i < 3; ++i) {
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
&ok_to_validate, i] {
mutex_lock l(mu);
while (!ok_to_execute) {
function_start.wait(l);
}
result = i;
ok_to_execute = false;
ok_to_validate = true;
function_end.notify_one();
});
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
fns[i]);
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
fns[i]);
}
run_handler_thread_pool->Start();
run_handler_thread_pool->SetThreadWorkSources(
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
// Validate that the execution is round robin.
mutex_lock l(mu);
for (int round = 0; round < 2; ++round) {
for (int i = 0; i < 3; ++i) {
ok_to_execute = true;
function_start.notify_one();
while (!ok_to_validate) {
function_end.wait(l);
}
ok_to_validate = false;
EXPECT_EQ(result, i);
}
}
delete run_handler_thread_pool;
}
TEST(RunHandlerThreadPool, MultipleSubThreadPool) {
// Set up environment for 2 sub thread pools.
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true);
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5",
true);
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1",
true);
Eigen::MaxSizeVector<mutex> waiters_mu(2);
waiters_mu.resize(2);
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
waiters.resize(2);
internal::RunHandlerThreadPool* run_handler_thread_pool =
new internal::RunHandlerThreadPool(
/*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0,
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
&waiters);
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(4);
thread_work_sources.resize(4);
internal::ThreadWorkSource tws[4];
for (int i = 0; i < 4; ++i) {
tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]);
thread_work_sources[i] = &tws[i];
}
int result = 0;
mutex mu;
bool ok_to_execute = false;
bool ok_to_validate = false;
condition_variable function_start;
condition_variable function_end;
std::vector<std::function<void()>> fns;
for (int i = 0; i < 4; ++i) {
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
&ok_to_validate, i] {
mutex_lock l(mu);
while (!ok_to_execute) {
function_start.wait(l);
}
result = i;
ok_to_execute = false;
ok_to_validate = true;
function_end.notify_one();
});
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
fns[i]);
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
fns[i]);
}
run_handler_thread_pool->StartOneThreadForTesting();
run_handler_thread_pool->SetThreadWorkSources(
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
run_handler_thread_pool->SetThreadWorkSources(
/*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
// Pick tasks from the given sub thread pool's requests in a round robin fashion.
mutex_lock l(mu);
for (int round = 0; round < 2; ++round) {
for (int i = 0; i < 2; ++i) {
ok_to_execute = true;
function_start.notify_one();
while (!ok_to_validate) {
function_end.wait(l);
}
ok_to_validate = false;
EXPECT_EQ(result, i);
}
}
// Pick a task from any request if there are no tasks from the requests in
// the sub thread pool.
for (int i = 0; i < 2; ++i) {
for (int round = 0; round < 2; ++round) {
ok_to_execute = true;
function_start.notify_one();
while (!ok_to_validate) {
function_end.wait(l);
}
ok_to_validate = false;
EXPECT_EQ(result, i + 2);
}
}
delete run_handler_thread_pool;
}
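The two tests above configure the pool through environment variables such as TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL and the comma-separated percentage lists, which the implementation reads via ParamFromEnvWithDefault. As a rough illustration of parsing such a list, here is a standalone stand-in; it is not TensorFlow's ParamFromEnvWithDefault and its name and behavior are assumptions made for the example.

#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Illustrative helper: read a comma-separated environment variable such as
// TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE into a double list,
// falling back to a default when the variable is unset.
std::vector<double> DemoDoublesFromEnv(const char* name,
                                       std::vector<double> default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return default_value;
  std::vector<double> result;
  std::stringstream ss(raw);
  std::string item;
  while (std::getline(ss, item, ',')) {
    result.push_back(std::stod(item));
  }
  return result;
}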
SessionOptions DefaultSessionOptions() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;