Refactor code for run handler thread pool, and add some unit tests.
PiperOrigin-RevId: 300835548 Change-Id: Ieb24ef54a327657af6452aaad893cb39fa02bd1c
This commit is contained in:
parent
8e80c097df
commit
f4cd9ee784
@ -41,79 +41,52 @@ namespace {
|
|||||||
static constexpr int32 kMaxConcurrentHandlers = 128;
|
static constexpr int32 kMaxConcurrentHandlers = 128;
|
||||||
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
|
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
|
||||||
|
|
||||||
// TODO(azaks): Refactor with thread:ThreadPool
|
typedef typename internal::RunHandlerEnvironment::Task Task;
|
||||||
class RunHandlerEnvironment {
|
|
||||||
typedef Thread EnvThread;
|
|
||||||
struct TaskImpl {
|
|
||||||
std::function<void()> f;
|
|
||||||
Context context;
|
|
||||||
uint64 trace_id;
|
|
||||||
};
|
|
||||||
Env* const env_;
|
|
||||||
const ThreadOptions thread_options_;
|
|
||||||
const string name_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
struct Task {
|
|
||||||
std::unique_ptr<TaskImpl> f;
|
|
||||||
};
|
|
||||||
|
|
||||||
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
|
|
||||||
const string& name)
|
|
||||||
: env_(env), thread_options_(thread_options), name_(name) {}
|
|
||||||
|
|
||||||
EnvThread* CreateThread(std::function<void()> f) {
|
|
||||||
return env_->StartThread(thread_options_, name_, [=]() {
|
|
||||||
// Set the processor flag to flush denormals to zero.
|
|
||||||
port::ScopedFlushDenormal flush;
|
|
||||||
// Set the processor rounding mode to ROUND TO NEAREST.
|
|
||||||
port::ScopedSetRound round(FE_TONEAREST);
|
|
||||||
if (thread_options_.numa_node != port::kNUMANoAffinity) {
|
|
||||||
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
|
|
||||||
}
|
|
||||||
f();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Task CreateTask(std::function<void()> f) {
|
|
||||||
uint64 id = 0;
|
|
||||||
if (tracing::EventCollector::IsEnabled()) {
|
|
||||||
id = tracing::GetUniqueArg();
|
|
||||||
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
|
|
||||||
}
|
|
||||||
return Task{
|
|
||||||
std::unique_ptr<TaskImpl>(new TaskImpl{
|
|
||||||
std::move(f),
|
|
||||||
Context(ContextKind::kThread),
|
|
||||||
id,
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExecuteTask(const Task& t) {
|
|
||||||
WithContext wc(t.f->context);
|
|
||||||
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
|
|
||||||
t.f->trace_id);
|
|
||||||
t.f->f();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef typename RunHandlerEnvironment::Task Task;
|
|
||||||
typedef Eigen::RunQueue<Task, 1024> Queue;
|
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||||
|
|
||||||
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
} // namespace
|
||||||
// queue them in LIFO order rather than the FIFO order used by a single
|
|
||||||
// condition variable.
|
namespace internal {
|
||||||
struct Waiter {
|
RunHandlerEnvironment::RunHandlerEnvironment(
|
||||||
Waiter() {
|
Env* env, const ThreadOptions& thread_options, const string& name)
|
||||||
next = this;
|
: env_(env), thread_options_(thread_options), name_(name) {}
|
||||||
prev = this;
|
|
||||||
|
RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
|
||||||
|
std::function<void()> f) {
|
||||||
|
return env_->StartThread(thread_options_, name_, [=]() {
|
||||||
|
// Set the processor flag to flush denormals to zero.
|
||||||
|
port::ScopedFlushDenormal flush;
|
||||||
|
// Set the processor rounding mode to ROUND TO NEAREST.
|
||||||
|
port::ScopedSetRound round(FE_TONEAREST);
|
||||||
|
if (thread_options_.numa_node != port::kNUMANoAffinity) {
|
||||||
|
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
|
||||||
|
}
|
||||||
|
f();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
|
||||||
|
std::function<void()> f) {
|
||||||
|
uint64 id = 0;
|
||||||
|
if (tracing::EventCollector::IsEnabled()) {
|
||||||
|
id = tracing::GetUniqueArg();
|
||||||
|
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
|
||||||
}
|
}
|
||||||
condition_variable cv;
|
return Task{
|
||||||
mutex mu;
|
std::unique_ptr<TaskImpl>(new TaskImpl{
|
||||||
Waiter* next;
|
std::move(f),
|
||||||
Waiter* prev;
|
Context(ContextKind::kThread),
|
||||||
};
|
id,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunHandlerEnvironment::ExecuteTask(const Task& t) {
|
||||||
|
WithContext wc(t.f->context);
|
||||||
|
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
|
||||||
|
t.f->trace_id);
|
||||||
|
t.f->f();
|
||||||
|
}
|
||||||
|
|
||||||
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
|
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
|
||||||
int max_sleep_micros) {
|
int max_sleep_micros) {
|
||||||
@ -150,442 +123,359 @@ void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ThreadWorkSource {
|
ThreadWorkSource::ThreadWorkSource()
|
||||||
public:
|
: non_blocking_work_sharding_factor_(
|
||||||
ThreadWorkSource()
|
static_cast<int32>(ParamFromEnvWithDefault(
|
||||||
: non_blocking_work_sharding_factor_(
|
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
|
||||||
static_cast<int32>(ParamFromEnvWithDefault(
|
non_blocking_work_queues_(non_blocking_work_sharding_factor_),
|
||||||
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
|
blocking_inflight_(0),
|
||||||
non_blocking_work_queues_(non_blocking_work_sharding_factor_),
|
non_blocking_inflight_(0),
|
||||||
blocking_inflight_(0),
|
traceme_id_(0),
|
||||||
non_blocking_inflight_(0),
|
version_(0),
|
||||||
traceme_id_(0),
|
sub_thread_pool_waiter_(nullptr) {
|
||||||
version_(0),
|
queue_waiters_.next = &queue_waiters_;
|
||||||
sub_thread_pool_waiter_(nullptr) {
|
queue_waiters_.prev = &queue_waiters_;
|
||||||
queue_waiters_.next = &queue_waiters_;
|
for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
|
||||||
queue_waiters_.prev = &queue_waiters_;
|
non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
|
||||||
for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
|
}
|
||||||
non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
|
}
|
||||||
|
|
||||||
|
ThreadWorkSource::~ThreadWorkSource() {
|
||||||
|
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
|
||||||
|
delete non_blocking_work_queues_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
|
||||||
|
mutex* mu = nullptr;
|
||||||
|
Queue* task_queue = nullptr;
|
||||||
|
thread_local int64 closure_counter = 0;
|
||||||
|
|
||||||
|
if (!is_blocking) {
|
||||||
|
int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
|
||||||
|
task_queue = &(non_blocking_work_queues_[queue_index]->queue);
|
||||||
|
mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
|
||||||
|
} else {
|
||||||
|
task_queue = &blocking_work_queue_;
|
||||||
|
mu = &blocking_queue_op_mu_;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
mutex_lock l(*mu);
|
||||||
|
// For a given queue, only one thread can call PushFront.
|
||||||
|
t = task_queue->PushFront(std::move(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
Waiter* w = nullptr;
|
||||||
|
bool use_sub_thread_pool =
|
||||||
|
ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
|
||||||
|
|
||||||
|
Waiter* waiter_queue;
|
||||||
|
mutex* waiter_queue_mu;
|
||||||
|
if (use_sub_thread_pool) {
|
||||||
|
// When we use multiple sub thread pools, free threads wait on sub
|
||||||
|
// thread pool waiting queues. Wake up threads from sub thread waiting
|
||||||
|
// queues.
|
||||||
|
// The waiting queues are defined at RunHandlerPool.
|
||||||
|
// Get the waiter_queue and corresponding mutex. Note, the thread work
|
||||||
|
// source may change afterwards if a new request comes or an old request
|
||||||
|
// finishes.
|
||||||
|
tf_shared_lock lock(run_handler_waiter_mu_);
|
||||||
|
waiter_queue = sub_thread_pool_waiter_;
|
||||||
|
waiter_queue_mu = sub_thread_pool_waiter_mu_;
|
||||||
|
} else {
|
||||||
|
waiter_queue = &queue_waiters_;
|
||||||
|
waiter_queue_mu = &waiters_mu_;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
mutex_lock l(*waiter_queue_mu);
|
||||||
|
if (waiter_queue->next != waiter_queue) {
|
||||||
|
// Remove waiter from the LIFO queue
|
||||||
|
w = waiter_queue->next;
|
||||||
|
|
||||||
|
CHECK(w->prev != w); // Crash OK.
|
||||||
|
CHECK(w->next != w); // Crash OK.
|
||||||
|
|
||||||
|
w->next->prev = w->prev;
|
||||||
|
w->prev->next = w->next;
|
||||||
|
|
||||||
|
// Use `w->next == &w` to indicate that the waiter has been removed
|
||||||
|
// from the queue.
|
||||||
|
w->next = w;
|
||||||
|
w->prev = w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (w != nullptr) {
|
||||||
|
// We call notify_one() without any locks, so we can miss notifications.
|
||||||
|
// The wake up logic is best effort and a thread will wake in short
|
||||||
|
// period of time in case a notification is missed.
|
||||||
|
w->cv.notify_one();
|
||||||
|
}
|
||||||
|
VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
|
||||||
|
<< traceme_id_.load(std::memory_order_relaxed);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
Task ThreadWorkSource::PopBlockingTask() {
|
||||||
|
return blocking_work_queue_.PopBack();
|
||||||
|
}
|
||||||
|
|
||||||
|
Task ThreadWorkSource::PopNonBlockingTask(int start_index,
|
||||||
|
bool search_from_all_queue) {
|
||||||
|
Task t;
|
||||||
|
unsigned sharding_factor = NonBlockingWorkShardingFactor();
|
||||||
|
for (unsigned j = 0; j < sharding_factor; ++j) {
|
||||||
|
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
|
||||||
|
->queue.PopBack();
|
||||||
|
if (t.f) {
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
if (!search_from_all_queue) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
|
||||||
|
thread_local Waiter waiter;
|
||||||
|
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
|
||||||
|
if (is_blocking) {
|
||||||
|
return blocking_work_queue_.Size();
|
||||||
|
} else {
|
||||||
|
unsigned total_size = 0;
|
||||||
|
for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
|
||||||
|
total_size += non_blocking_work_queues_[i]->queue.Size();
|
||||||
|
}
|
||||||
|
return total_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 ThreadWorkSource::GetTracemeId() {
|
||||||
|
return traceme_id_.load(std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }
|
||||||
|
|
||||||
|
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
|
||||||
|
{
|
||||||
|
tf_shared_lock lock(run_handler_waiter_mu_);
|
||||||
|
// Most of the request won't change sub pool for recomputation.
|
||||||
|
// Optimization for avoiding holding exclusive lock to reduce contention.
|
||||||
|
if (sub_thread_pool_waiter_ == waiter) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// If the current version is a newer version, no need to update.
|
||||||
|
if (version_ > version) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~ThreadWorkSource() {
|
mutex_lock l(run_handler_waiter_mu_);
|
||||||
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
|
sub_thread_pool_waiter_ = waiter;
|
||||||
delete non_blocking_work_queues_[i];
|
sub_thread_pool_waiter_mu_ = mutex;
|
||||||
}
|
version_ = version;
|
||||||
}
|
}
|
||||||
|
|
||||||
Task EnqueueTask(Task t, bool is_blocking) {
|
int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
|
||||||
mutex* mu = nullptr;
|
std::atomic<int64>* counter =
|
||||||
Queue* task_queue = nullptr;
|
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||||
thread_local int64 closure_counter = 0;
|
return counter->load(std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
if (!is_blocking) {
|
void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
|
||||||
int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
|
std::atomic<int64>* counter =
|
||||||
task_queue = &(non_blocking_work_queues_[queue_index]->queue);
|
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||||
mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
|
counter->fetch_add(1, std::memory_order_relaxed);
|
||||||
} else {
|
}
|
||||||
task_queue = &blocking_work_queue_;
|
|
||||||
mu = &blocking_queue_op_mu_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
|
||||||
|
std::atomic<int64>* counter =
|
||||||
|
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||||
|
counter->fetch_sub(1, std::memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
|
||||||
|
return non_blocking_work_sharding_factor_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ThreadWorkSource::ToString() {
|
||||||
|
return strings::StrCat("traceme_id = ", GetTracemeId(),
|
||||||
|
", inter queue size = ", TaskQueueSize(true),
|
||||||
|
", inter inflight = ", GetInflightTaskCount(true),
|
||||||
|
", intra queue size = ", TaskQueueSize(false),
|
||||||
|
", intra inflight = ", GetInflightTaskCount(false));
|
||||||
|
}
|
||||||
|
|
||||||
|
RunHandlerThreadPool::RunHandlerThreadPool(
|
||||||
|
int num_blocking_threads, int num_non_blocking_threads, Env* env,
|
||||||
|
const ThreadOptions& thread_options, const string& name,
|
||||||
|
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
||||||
|
Eigen::MaxSizeVector<Waiter>* queue_waiters)
|
||||||
|
: num_threads_(num_blocking_threads + num_non_blocking_threads),
|
||||||
|
num_blocking_threads_(num_blocking_threads),
|
||||||
|
num_non_blocking_threads_(num_non_blocking_threads),
|
||||||
|
thread_data_(num_threads_),
|
||||||
|
env_(env, thread_options, name),
|
||||||
|
name_(name),
|
||||||
|
waiters_mu_(waiters_mu),
|
||||||
|
queue_waiters_(queue_waiters),
|
||||||
|
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
|
||||||
|
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
|
||||||
|
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
|
||||||
|
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
|
||||||
|
std::vector<int>({num_blocking_threads / 2,
|
||||||
|
num_blocking_threads - num_blocking_threads / 2}))),
|
||||||
|
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
|
||||||
|
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
|
||||||
|
std::vector<double>({0, 0.4}))),
|
||||||
|
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
|
||||||
|
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
||||||
|
std::vector<double>({0.4, 1}))) {
|
||||||
|
thread_data_.resize(num_threads_);
|
||||||
|
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
|
||||||
|
<< num_blocking_threads_ << " blocking threads and "
|
||||||
|
<< num_non_blocking_threads_ << " non-blocking threads.";
|
||||||
|
}
|
||||||
|
|
||||||
|
RunHandlerThreadPool::~RunHandlerThreadPool() {
|
||||||
|
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
|
||||||
|
|
||||||
|
cancelled_ = true;
|
||||||
|
for (size_t i = 0; i < thread_data_.size(); ++i) {
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu);
|
mutex_lock l(thread_data_[i].mu);
|
||||||
// For a given queue, only one thread can call PushFront.
|
thread_data_[i].sources_not_empty.notify_all();
|
||||||
t = task_queue->PushFront(std::move(t));
|
|
||||||
}
|
}
|
||||||
|
thread_data_[i].thread.reset();
|
||||||
Waiter* w = nullptr;
|
|
||||||
bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
|
|
||||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
|
|
||||||
|
|
||||||
Waiter* waiter_queue;
|
|
||||||
mutex* waiter_queue_mu;
|
|
||||||
if (use_sub_thread_pool) {
|
|
||||||
// When we use multiple sub thread pools, free threads wait on sub
|
|
||||||
// thread pool waiting queues. Wake up threads from sub thread waiting
|
|
||||||
// queues.
|
|
||||||
// The waiting queues are defined at RunHandlerPool.
|
|
||||||
// Get the waiter_queue and corresponding mutex. Note, the thread work
|
|
||||||
// source may change afterwards if a new request comes or an old request
|
|
||||||
// finishes.
|
|
||||||
tf_shared_lock lock(run_handler_waiter_mu_);
|
|
||||||
waiter_queue = sub_thread_pool_waiter_;
|
|
||||||
waiter_queue_mu = sub_thread_pool_waiter_mu_;
|
|
||||||
} else {
|
|
||||||
waiter_queue = &queue_waiters_;
|
|
||||||
waiter_queue_mu = &waiters_mu_;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
mutex_lock l(*waiter_queue_mu);
|
|
||||||
if (waiter_queue->next != waiter_queue) {
|
|
||||||
// Remove waiter from the LIFO queue
|
|
||||||
w = waiter_queue->next;
|
|
||||||
|
|
||||||
CHECK(w->prev != w);
|
|
||||||
CHECK(w->next != w);
|
|
||||||
|
|
||||||
w->next->prev = w->prev;
|
|
||||||
w->prev->next = w->next;
|
|
||||||
|
|
||||||
// Use `w->next == &w` to indicate that the waiter has been removed
|
|
||||||
// from the queue.
|
|
||||||
w->next = w;
|
|
||||||
w->prev = w;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (w != nullptr) {
|
|
||||||
// We call notify_one() without any locks, so we can miss notifications.
|
|
||||||
// The wake up logic is best effort and a thread will wake in short
|
|
||||||
// period of time in case a notification is missed.
|
|
||||||
w->cv.notify_one();
|
|
||||||
}
|
|
||||||
VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
|
|
||||||
<< traceme_id_.load(std::memory_order_relaxed);
|
|
||||||
return t;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
|
void RunHandlerThreadPool::Start() {
|
||||||
|
cancelled_ = false;
|
||||||
Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
|
int num_blocking_threads = num_blocking_threads_;
|
||||||
Task t;
|
for (int i = 0; i < num_threads_; i++) {
|
||||||
unsigned sharding_factor = NonBlockingWorkShardingFactor();
|
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
|
||||||
for (unsigned j = 0; j < sharding_factor; ++j) {
|
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
|
||||||
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
|
if (i < num_threads_in_sub_thread_pool_[j]) {
|
||||||
->queue.PopBack();
|
sub_thread_pool_id = j;
|
||||||
if (t.f) {
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
if (!search_from_all_queue) {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return t;
|
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
|
||||||
|
thread_data_[i].thread.reset(
|
||||||
|
env_.CreateThread([this, i, num_blocking_threads]() {
|
||||||
|
WorkerLoop(i, i < num_blocking_threads);
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void WaitForWork(int max_sleep_micros) {
|
void RunHandlerThreadPool::StartOneThreadForTesting() {
|
||||||
thread_local Waiter waiter;
|
cancelled_ = false;
|
||||||
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
|
thread_data_[0].sub_thread_pool_id = 0;
|
||||||
|
thread_data_[0].thread.reset(
|
||||||
|
env_.CreateThread([this]() { WorkerLoop(0, true); }));
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
|
||||||
|
bool is_blocking,
|
||||||
|
std::function<void()> fn) {
|
||||||
|
Task t = env_.CreateTask(std::move(fn));
|
||||||
|
t = tws->EnqueueTask(std::move(t), is_blocking);
|
||||||
|
if (t.f) {
|
||||||
|
VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
|
||||||
|
<< tws->GetTracemeId();
|
||||||
|
env_.ExecuteTask(t);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int TaskQueueSize(bool is_blocking) {
|
// TODO(donglin) Change the task steal order to be round-robin such that if
|
||||||
if (is_blocking) {
|
// an attempt to steal task from request i failed, then attempt to steal task
|
||||||
return blocking_work_queue_.Size();
|
// from the next request in terms of the arrival time. This approach may
|
||||||
} else {
|
// provide better performance due to less lock retention. The drawback is that
|
||||||
unsigned total_size = 0;
|
// the profiler will be a bit harder to read.
|
||||||
for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
|
void RunHandlerThreadPool::SetThreadWorkSources(
|
||||||
total_size += non_blocking_work_queues_[i]->queue.Size();
|
int tid, int start_request_idx, uint64 version,
|
||||||
}
|
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
|
||||||
return total_size;
|
mutex_lock l(thread_data_[tid].mu);
|
||||||
}
|
if (version > thread_data_[tid].new_version) {
|
||||||
|
thread_data_[tid].new_version = version;
|
||||||
|
} else {
|
||||||
|
// A newer version is already updated. No need to update.
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
thread_data_[tid].new_thread_work_sources->resize(0);
|
||||||
int64 GetTracemeId() { return traceme_id_.load(std::memory_order_relaxed); }
|
if (use_sub_thread_pool_) {
|
||||||
|
for (int i = 0; i < thread_work_sources.size(); ++i) {
|
||||||
void SetTracemeId(int64 value) { traceme_id_ = value; }
|
|
||||||
|
|
||||||
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
|
|
||||||
{
|
|
||||||
tf_shared_lock lock(run_handler_waiter_mu_);
|
|
||||||
// Most of the request won't change sub pool for recomputation.
|
|
||||||
// Optimization for avoiding holding exclusive lock to reduce contention.
|
|
||||||
if (sub_thread_pool_waiter_ == waiter) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// If the current version is a newer version, no need to update.
|
|
||||||
if (version_ > version) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mutex_lock l(run_handler_waiter_mu_);
|
|
||||||
sub_thread_pool_waiter_ = waiter;
|
|
||||||
sub_thread_pool_waiter_mu_ = mutex;
|
|
||||||
version_ = version;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 GetInflightTaskCount(bool is_blocking) {
|
|
||||||
std::atomic<int64>* counter =
|
|
||||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
|
||||||
return counter->load(std::memory_order_relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
void IncrementInflightTaskCount(bool is_blocking) {
|
|
||||||
std::atomic<int64>* counter =
|
|
||||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
|
||||||
counter->fetch_add(1, std::memory_order_relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DecrementInflightTaskCount(bool is_blocking) {
|
|
||||||
std::atomic<int64>* counter =
|
|
||||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
|
||||||
counter->fetch_sub(1, std::memory_order_relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned NonBlockingWorkShardingFactor() {
|
|
||||||
return non_blocking_work_sharding_factor_;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ToString() {
|
|
||||||
return strings::StrCat("traceme_id = ", GetTracemeId(),
|
|
||||||
", inter queue size = ", TaskQueueSize(true),
|
|
||||||
", inter inflight = ", GetInflightTaskCount(true),
|
|
||||||
", intra queue size = ", TaskQueueSize(false),
|
|
||||||
", intra inflight = ", GetInflightTaskCount(false));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct NonBlockingQueue {
|
|
||||||
mutex queue_op_mu;
|
|
||||||
char pad[128];
|
|
||||||
Queue queue;
|
|
||||||
};
|
|
||||||
|
|
||||||
int32 non_blocking_work_sharding_factor_;
|
|
||||||
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
|
|
||||||
|
|
||||||
std::atomic<int64> blocking_inflight_;
|
|
||||||
std::atomic<int64> non_blocking_inflight_;
|
|
||||||
|
|
||||||
Queue blocking_work_queue_;
|
|
||||||
mutex blocking_queue_op_mu_;
|
|
||||||
char pad_[128];
|
|
||||||
mutex waiters_mu_;
|
|
||||||
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
|
|
||||||
std::atomic<int64> traceme_id_;
|
|
||||||
|
|
||||||
mutex run_handler_waiter_mu_;
|
|
||||||
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
|
||||||
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
|
||||||
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
|
||||||
};
|
|
||||||
|
|
||||||
class RunHandlerThreadPool {
|
|
||||||
public:
|
|
||||||
struct PerThread {
|
|
||||||
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
|
|
||||||
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
|
|
||||||
int thread_id; // Worker thread index in pool.
|
|
||||||
};
|
|
||||||
|
|
||||||
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
|
|
||||||
Env* env, const ThreadOptions& thread_options,
|
|
||||||
const string& name,
|
|
||||||
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
|
||||||
Eigen::MaxSizeVector<Waiter>* queue_waiters)
|
|
||||||
: num_threads_(num_blocking_threads + num_non_blocking_threads),
|
|
||||||
num_blocking_threads_(num_blocking_threads),
|
|
||||||
num_non_blocking_threads_(num_non_blocking_threads),
|
|
||||||
thread_data_(num_threads_),
|
|
||||||
env_(env, thread_options, name),
|
|
||||||
name_(name),
|
|
||||||
waiters_mu_(waiters_mu),
|
|
||||||
queue_waiters_(queue_waiters),
|
|
||||||
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
|
|
||||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
|
|
||||||
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
|
|
||||||
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
|
|
||||||
std::vector<int>(
|
|
||||||
{num_blocking_threads / 2,
|
|
||||||
num_blocking_threads - num_blocking_threads / 2}))),
|
|
||||||
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
|
|
||||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
|
|
||||||
std::vector<double>({0, 0.4}))),
|
|
||||||
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
|
|
||||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
|
||||||
std::vector<double>({0.4, 1}))) {
|
|
||||||
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
|
|
||||||
<< num_blocking_threads_ << " blocking threads and "
|
|
||||||
<< num_non_blocking_threads_ << " non-blocking threads.";
|
|
||||||
}
|
|
||||||
|
|
||||||
~RunHandlerThreadPool() {
|
|
||||||
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
|
|
||||||
|
|
||||||
cancelled_ = true;
|
|
||||||
for (size_t i = 0; i < thread_data_.size(); ++i) {
|
|
||||||
{
|
|
||||||
mutex_lock l(thread_data_[i].mu);
|
|
||||||
thread_data_[i].sources_not_empty.notify_all();
|
|
||||||
}
|
|
||||||
thread_data_[i].thread.reset();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Start() {
|
|
||||||
cancelled_ = false;
|
|
||||||
thread_data_.resize(num_threads_);
|
|
||||||
int num_blocking_threads = num_blocking_threads_;
|
|
||||||
for (int i = 0; i < num_threads_; i++) {
|
|
||||||
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
|
|
||||||
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
|
|
||||||
if (i < num_threads_in_sub_thread_pool_[j]) {
|
|
||||||
sub_thread_pool_id = j;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
|
|
||||||
thread_data_[i].thread.reset(
|
|
||||||
env_.CreateThread([this, i, num_blocking_threads]() {
|
|
||||||
WorkerLoop(i, i < num_blocking_threads);
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
|
|
||||||
std::function<void()> fn) {
|
|
||||||
Task t = env_.CreateTask(std::move(fn));
|
|
||||||
t = tws->EnqueueTask(std::move(t), is_blocking);
|
|
||||||
if (t.f) {
|
|
||||||
VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
|
|
||||||
<< tws->GetTracemeId();
|
|
||||||
env_.ExecuteTask(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set work queues from which the thread 'tid' can steal its work.
|
|
||||||
// The request with start_request_idx will be attempted first. Other requests
|
|
||||||
// will be attempted in FIFO order based on their arrival time.
|
|
||||||
|
|
||||||
// TODO(donglin) Change the task steal order to be round-robin such that if
|
|
||||||
// an attempt to steal task from request i failed, then attempt to steal task
|
|
||||||
// from the next request in terms of the arrival time. This approach may
|
|
||||||
// provide better performance due to less lock retention. The drawback is that
|
|
||||||
// the profiler will be a bit harder to read.
|
|
||||||
void SetThreadWorkSources(
|
|
||||||
int tid, int start_request_idx, uint64 version,
|
|
||||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
|
|
||||||
mutex_lock l(thread_data_[tid].mu);
|
|
||||||
if (version > thread_data_[tid].new_version) {
|
|
||||||
thread_data_[tid].new_version = version;
|
|
||||||
} else {
|
|
||||||
// A newer version is already updated. No need to update.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
thread_data_[tid].new_thread_work_sources->resize(0);
|
|
||||||
|
|
||||||
if (use_sub_thread_pool_) {
|
|
||||||
for (int i = 0; i < thread_work_sources.size(); ++i) {
|
|
||||||
thread_data_[tid].new_thread_work_sources->emplace_back(
|
|
||||||
thread_work_sources[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
thread_data_[tid].new_thread_work_sources->emplace_back(
|
thread_data_[tid].new_thread_work_sources->emplace_back(
|
||||||
thread_work_sources[start_request_idx]);
|
thread_work_sources[i]);
|
||||||
// The number of shards for the queue. Threads in each shard will
|
}
|
||||||
// prioritize different thread_work_sources. Increase the number of shards
|
} else {
|
||||||
// could decrease the contention in the queue. For example, when
|
thread_data_[tid].new_thread_work_sources->emplace_back(
|
||||||
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
|
thread_work_sources[start_request_idx]);
|
||||||
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
|
// The number of shards for the queue. Threads in each shard will
|
||||||
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
|
// prioritize different thread_work_sources. Increase the number of shards
|
||||||
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
|
// could decrease the contention in the queue. For example, when
|
||||||
// 4... for the other half of the threads.
|
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
|
||||||
int num_shards =
|
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
|
||||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
|
||||||
int token = tid % num_shards;
|
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
|
||||||
for (int i = 0; i < num_shards; ++i) {
|
// 4... for the other half of the threads.
|
||||||
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
|
int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
||||||
if (j != start_request_idx) {
|
int token = tid % num_shards;
|
||||||
thread_data_[tid].new_thread_work_sources->emplace_back(
|
for (int i = 0; i < num_shards; ++i) {
|
||||||
thread_work_sources[j]);
|
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
|
||||||
}
|
if (j != start_request_idx) {
|
||||||
|
thread_data_[tid].new_thread_work_sources->emplace_back(
|
||||||
|
thread_work_sources[j]);
|
||||||
}
|
}
|
||||||
token = (token + 1) % num_shards;
|
|
||||||
}
|
}
|
||||||
thread_data_[tid].sources_not_empty.notify_all();
|
token = (token + 1) % num_shards;
|
||||||
}
|
}
|
||||||
|
thread_data_[tid].sources_not_empty.notify_all();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
PerThread* GetPerThread() {
|
RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
|
||||||
thread_local PerThread per_thread_;
|
thread_local RunHandlerThreadPool::PerThread per_thread_;
|
||||||
PerThread* pt = &per_thread_;
|
RunHandlerThreadPool::PerThread* pt = &per_thread_;
|
||||||
return pt;
|
return pt;
|
||||||
|
}
|
||||||
|
|
||||||
|
int RunHandlerThreadPool::CurrentThreadId() const {
|
||||||
|
const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
|
||||||
|
if (pt->pool == this) {
|
||||||
|
return pt->thread_id;
|
||||||
|
} else {
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int CurrentThreadId() const {
|
int RunHandlerThreadPool::NumThreads() const { return num_threads_; }
|
||||||
const PerThread* pt =
|
|
||||||
const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
|
|
||||||
if (pt->pool == this) {
|
|
||||||
return pt->thread_id;
|
|
||||||
} else {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int NumThreads() const { return num_threads_; }
|
int RunHandlerThreadPool::NumBlockingThreads() const {
|
||||||
|
return num_blocking_threads_;
|
||||||
|
}
|
||||||
|
|
||||||
int NumBlockingThreads() const { return num_blocking_threads_; }
|
int RunHandlerThreadPool::NumNonBlockingThreads() const {
|
||||||
|
return num_non_blocking_threads_;
|
||||||
|
}
|
||||||
|
|
||||||
int NumNonBlockingThreads() const { return num_non_blocking_threads_; }
|
RunHandlerThreadPool::ThreadData::ThreadData()
|
||||||
|
: new_version(0),
|
||||||
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
|
current_index(0),
|
||||||
|
new_thread_work_sources(
|
||||||
// Search tasks from Requets range searching_range_start to
|
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
|
||||||
// searching_range_end. If there is no tasks in the search range and
|
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||||
// may_steal_blocking_work is true, then search from all requests.
|
kMaxConcurrentHandlers)))),
|
||||||
Task FindTask(
|
current_version(0),
|
||||||
int searching_range_start, int searching_range_end, int thread_id,
|
current_thread_work_sources(
|
||||||
int sub_thread_pool_id, int max_blocking_inflight,
|
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
|
||||||
bool may_steal_blocking_work,
|
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
|
kMaxConcurrentHandlers)))) {}
|
||||||
bool* task_from_blocking_queue, ThreadWorkSource** tws);
|
|
||||||
|
|
||||||
void WaitForWork(bool is_blocking, int thread_id,
|
|
||||||
int32 max_blocking_inflight);
|
|
||||||
|
|
||||||
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct ThreadData {
|
|
||||||
ThreadData()
|
|
||||||
: new_version(0),
|
|
||||||
current_index(0),
|
|
||||||
new_thread_work_sources(new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
|
||||||
static_cast<int32>(ParamFromEnvWithDefault(
|
|
||||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
|
||||||
kMaxConcurrentHandlers)))),
|
|
||||||
current_version(0),
|
|
||||||
current_thread_work_sources(
|
|
||||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
|
||||||
static_cast<int32>(ParamFromEnvWithDefault(
|
|
||||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
|
||||||
kMaxConcurrentHandlers)))) {}
|
|
||||||
mutex mu;
|
|
||||||
uint64 new_version;
|
|
||||||
condition_variable sources_not_empty;
|
|
||||||
std::unique_ptr<Thread> thread;
|
|
||||||
int current_index;
|
|
||||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
|
||||||
new_thread_work_sources TF_GUARDED_BY(mu);
|
|
||||||
|
|
||||||
uint64 current_version;
|
|
||||||
// Should only be accessed by one thread.
|
|
||||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
|
||||||
current_thread_work_sources;
|
|
||||||
|
|
||||||
int sub_thread_pool_id;
|
|
||||||
};
|
|
||||||
|
|
||||||
const int num_threads_;
|
|
||||||
const int num_blocking_threads_;
|
|
||||||
const int num_non_blocking_threads_;
|
|
||||||
Eigen::MaxSizeVector<ThreadData> thread_data_;
|
|
||||||
RunHandlerEnvironment env_;
|
|
||||||
std::atomic<bool> cancelled_;
|
|
||||||
string name_;
|
|
||||||
Eigen::MaxSizeVector<mutex>* waiters_mu_;
|
|
||||||
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
|
|
||||||
|
|
||||||
bool use_sub_thread_pool_;
|
|
||||||
std::vector<int> num_threads_in_sub_thread_pool_;
|
|
||||||
|
|
||||||
// Threads in each sub thread pool will search tasks from the given
|
|
||||||
// start_request_percentage to end_request_percentage in a round robin
|
|
||||||
// fashion.
|
|
||||||
std::vector<double> sub_thread_pool_start_request_percentage_;
|
|
||||||
std::vector<double> sub_thread_pool_end_request_percentage_;
|
|
||||||
};
|
|
||||||
|
|
||||||
Task RunHandlerThreadPool::FindTask(
|
Task RunHandlerThreadPool::FindTask(
|
||||||
int searching_range_start, int searching_range_end, int thread_id,
|
int searching_range_start, int searching_range_end, int thread_id,
|
||||||
@ -597,10 +487,9 @@ Task RunHandlerThreadPool::FindTask(
|
|||||||
int current_index = thread_data_[thread_id].current_index;
|
int current_index = thread_data_[thread_id].current_index;
|
||||||
*task_from_blocking_queue = false;
|
*task_from_blocking_queue = false;
|
||||||
|
|
||||||
// TODO(chaox): Change the search algorithm from round robin to random
|
|
||||||
// walk.
|
|
||||||
for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
|
for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
|
||||||
if (current_index >= searching_range_end) {
|
if (current_index >= searching_range_end ||
|
||||||
|
current_index < searching_range_start) {
|
||||||
current_index = searching_range_start;
|
current_index = searching_range_start;
|
||||||
}
|
}
|
||||||
*tws = thread_work_sources[current_index];
|
*tws = thread_work_sources[current_index];
|
||||||
@ -821,7 +710,7 @@ void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
|
|||||||
tws->WaitForWork(kMaxSleepMicros);
|
tws->WaitForWork(kMaxSleepMicros);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace internal
|
||||||
|
|
||||||
// Contains the concrete implementation of the RunHandler.
|
// Contains the concrete implementation of the RunHandler.
|
||||||
// Externally visible RunHandler class simply forwards the work to this one.
|
// Externally visible RunHandler class simply forwards the work to this one.
|
||||||
@ -847,7 +736,7 @@ class RunHandler::Impl {
|
|||||||
|
|
||||||
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
|
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
|
||||||
|
|
||||||
ThreadWorkSource* tws() { return &tws_; }
|
internal::ThreadWorkSource* tws() { return &tws_; }
|
||||||
|
|
||||||
int64 priority() { return options_.priority(); }
|
int64 priority() { return options_.priority(); }
|
||||||
|
|
||||||
@ -869,7 +758,7 @@ class RunHandler::Impl {
|
|||||||
uint64 start_time_us_;
|
uint64 start_time_us_;
|
||||||
int64 step_id_;
|
int64 step_id_;
|
||||||
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
|
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
|
||||||
ThreadWorkSource tws_;
|
internal::ThreadWorkSource tws_;
|
||||||
RunOptions::Experimental::RunHandlerPoolOptions options_;
|
RunOptions::Experimental::RunHandlerPoolOptions options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -885,7 +774,7 @@ class RunHandlerPool::Impl {
|
|||||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||||
queue_waiters_(
|
queue_waiters_(
|
||||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||||
run_handler_thread_pool_(new RunHandlerThreadPool(
|
run_handler_thread_pool_(new internal::RunHandlerThreadPool(
|
||||||
num_inter_op_threads, num_intra_op_threads, Env::Default(),
|
num_inter_op_threads, num_intra_op_threads, Env::Default(),
|
||||||
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
|
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
|
||||||
&queue_waiters_)),
|
&queue_waiters_)),
|
||||||
@ -924,7 +813,7 @@ class RunHandlerPool::Impl {
|
|||||||
run_handler_thread_pool_.reset();
|
run_handler_thread_pool_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
RunHandlerThreadPool* run_handler_thread_pool() {
|
internal::RunHandlerThreadPool* run_handler_thread_pool() {
|
||||||
return run_handler_thread_pool_.get();
|
return run_handler_thread_pool_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -936,10 +825,11 @@ class RunHandlerPool::Impl {
|
|||||||
int64 step_id, int64 timeout_in_ms,
|
int64 step_id, int64 timeout_in_ms,
|
||||||
const RunOptions::Experimental::RunHandlerPoolOptions& options)
|
const RunOptions::Experimental::RunHandlerPoolOptions& options)
|
||||||
TF_LOCKS_EXCLUDED(mu_) {
|
TF_LOCKS_EXCLUDED(mu_) {
|
||||||
thread_local std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
thread_local std::unique_ptr<
|
||||||
|
Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
|
||||||
thread_work_sources =
|
thread_work_sources =
|
||||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>(
|
std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
|
||||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
|
||||||
static_cast<int32>(ParamFromEnvWithDefault(
|
static_cast<int32>(ParamFromEnvWithDefault(
|
||||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||||
kMaxConcurrentHandlers))));
|
kMaxConcurrentHandlers))));
|
||||||
@ -1035,7 +925,8 @@ class RunHandlerPool::Impl {
|
|||||||
private:
|
private:
|
||||||
void RecomputePoolStats(
|
void RecomputePoolStats(
|
||||||
int num_active_requests, uint64 version,
|
int num_active_requests, uint64 version,
|
||||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
|
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
|
||||||
|
thread_work_sources);
|
||||||
|
|
||||||
void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
@ -1046,9 +937,9 @@ class RunHandlerPool::Impl {
|
|||||||
const int max_handlers_;
|
const int max_handlers_;
|
||||||
|
|
||||||
Eigen::MaxSizeVector<mutex> waiters_mu_;
|
Eigen::MaxSizeVector<mutex> waiters_mu_;
|
||||||
Eigen::MaxSizeVector<Waiter> queue_waiters_;
|
Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
|
||||||
|
|
||||||
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
|
std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
|
||||||
// Thread compatible part used only by lock under RunHandlerPool.
|
// Thread compatible part used only by lock under RunHandlerPool.
|
||||||
// Handlers are sorted by start time.
|
// Handlers are sorted by start time.
|
||||||
// TODO(azaks): sort by the remaining latency budget.
|
// TODO(azaks): sort by the remaining latency budget.
|
||||||
@ -1070,7 +961,8 @@ class RunHandlerPool::Impl {
|
|||||||
|
|
||||||
void RunHandlerPool::Impl::RecomputePoolStats(
|
void RunHandlerPool::Impl::RecomputePoolStats(
|
||||||
int num_active_requests, uint64 version,
|
int num_active_requests, uint64 version,
|
||||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
|
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
|
||||||
|
thread_work_sources) {
|
||||||
if (num_active_requests == 0) return;
|
if (num_active_requests == 0) return;
|
||||||
|
|
||||||
int sub_thread_pool_id = 0;
|
int sub_thread_pool_id = 0;
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||||
|
#include "tensorflow/core/platform/context.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
@ -106,6 +107,208 @@ class RunHandler {
|
|||||||
Impl* impl_; // NOT OWNED.
|
Impl* impl_; // NOT OWNED.
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// TODO(azaks): Refactor with thread:ThreadPool
|
||||||
|
class RunHandlerEnvironment {
|
||||||
|
typedef Thread EnvThread;
|
||||||
|
struct TaskImpl {
|
||||||
|
std::function<void()> f;
|
||||||
|
Context context;
|
||||||
|
uint64 trace_id;
|
||||||
|
};
|
||||||
|
Env* const env_;
|
||||||
|
const ThreadOptions thread_options_;
|
||||||
|
const string name_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
struct Task {
|
||||||
|
std::unique_ptr<TaskImpl> f;
|
||||||
|
};
|
||||||
|
|
||||||
|
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
|
||||||
|
const string& name);
|
||||||
|
|
||||||
|
EnvThread* CreateThread(std::function<void()> f);
|
||||||
|
|
||||||
|
Task CreateTask(std::function<void()> f);
|
||||||
|
|
||||||
|
void ExecuteTask(const Task& t);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef typename RunHandlerEnvironment::Task Task;
|
||||||
|
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||||
|
|
||||||
|
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
||||||
|
// queue them in LIFO order rather than the FIFO order used by a single
|
||||||
|
// condition variable.
|
||||||
|
struct Waiter {
|
||||||
|
Waiter() {
|
||||||
|
next = this;
|
||||||
|
prev = this;
|
||||||
|
}
|
||||||
|
condition_variable cv;
|
||||||
|
mutex mu;
|
||||||
|
Waiter* next;
|
||||||
|
Waiter* prev;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ThreadWorkSource {
|
||||||
|
public:
|
||||||
|
ThreadWorkSource();
|
||||||
|
|
||||||
|
~ThreadWorkSource();
|
||||||
|
|
||||||
|
Task EnqueueTask(Task t, bool is_blocking);
|
||||||
|
|
||||||
|
Task PopBlockingTask();
|
||||||
|
|
||||||
|
Task PopNonBlockingTask(int start_index, bool search_from_all_queue);
|
||||||
|
|
||||||
|
void WaitForWork(int max_sleep_micros);
|
||||||
|
|
||||||
|
int TaskQueueSize(bool is_blocking);
|
||||||
|
|
||||||
|
int64 GetTracemeId();
|
||||||
|
|
||||||
|
void SetTracemeId(int64 value);
|
||||||
|
|
||||||
|
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);
|
||||||
|
|
||||||
|
int64 GetInflightTaskCount(bool is_blocking);
|
||||||
|
|
||||||
|
void IncrementInflightTaskCount(bool is_blocking);
|
||||||
|
|
||||||
|
void DecrementInflightTaskCount(bool is_blocking);
|
||||||
|
|
||||||
|
unsigned NonBlockingWorkShardingFactor();
|
||||||
|
|
||||||
|
std::string ToString();
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct NonBlockingQueue {
|
||||||
|
mutex queue_op_mu;
|
||||||
|
char pad[128];
|
||||||
|
Queue queue;
|
||||||
|
};
|
||||||
|
|
||||||
|
int32 non_blocking_work_sharding_factor_;
|
||||||
|
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
|
||||||
|
|
||||||
|
std::atomic<int64> blocking_inflight_;
|
||||||
|
std::atomic<int64> non_blocking_inflight_;
|
||||||
|
|
||||||
|
Queue blocking_work_queue_;
|
||||||
|
mutex blocking_queue_op_mu_;
|
||||||
|
char pad_[128];
|
||||||
|
mutex waiters_mu_;
|
||||||
|
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
|
||||||
|
std::atomic<int64> traceme_id_;
|
||||||
|
|
||||||
|
mutex run_handler_waiter_mu_;
|
||||||
|
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||||
|
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||||
|
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||||
|
};
|
||||||
|
|
||||||
|
class RunHandlerThreadPool {
|
||||||
|
public:
|
||||||
|
struct PerThread {
|
||||||
|
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
|
||||||
|
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
|
||||||
|
int thread_id; // Worker thread index in pool.
|
||||||
|
};
|
||||||
|
|
||||||
|
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
|
||||||
|
Env* env, const ThreadOptions& thread_options,
|
||||||
|
const string& name,
|
||||||
|
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
||||||
|
Eigen::MaxSizeVector<Waiter>* queue_waiters);
|
||||||
|
|
||||||
|
~RunHandlerThreadPool();
|
||||||
|
|
||||||
|
void Start();
|
||||||
|
|
||||||
|
void StartOneThreadForTesting();
|
||||||
|
|
||||||
|
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
|
||||||
|
std::function<void()> fn);
|
||||||
|
|
||||||
|
// Set work queues from which the thread 'tid' can steal its work.
|
||||||
|
// The request with start_request_idx will be attempted first. Other requests
|
||||||
|
// will be attempted in FIFO order based on their arrival time.
|
||||||
|
void SetThreadWorkSources(
|
||||||
|
int tid, int start_request_idx, uint64 version,
|
||||||
|
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
|
||||||
|
|
||||||
|
PerThread* GetPerThread();
|
||||||
|
|
||||||
|
int CurrentThreadId() const;
|
||||||
|
|
||||||
|
int NumThreads() const;
|
||||||
|
|
||||||
|
int NumBlockingThreads() const;
|
||||||
|
|
||||||
|
int NumNonBlockingThreads() const;
|
||||||
|
|
||||||
|
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
|
||||||
|
|
||||||
|
// Search tasks from Requets range searching_range_start to
|
||||||
|
// searching_range_end. If there is no tasks in the search range and
|
||||||
|
// may_steal_blocking_work is true, then search from all requests.
|
||||||
|
Task FindTask(
|
||||||
|
int searching_range_start, int searching_range_end, int thread_id,
|
||||||
|
int sub_thread_pool_id, int max_blocking_inflight,
|
||||||
|
bool may_steal_blocking_work,
|
||||||
|
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
|
||||||
|
bool* task_from_blocking_queue, ThreadWorkSource** tws);
|
||||||
|
|
||||||
|
void WaitForWork(bool is_blocking, int thread_id,
|
||||||
|
int32 max_blocking_inflight);
|
||||||
|
|
||||||
|
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct ThreadData {
|
||||||
|
ThreadData();
|
||||||
|
mutex mu;
|
||||||
|
uint64 new_version;
|
||||||
|
condition_variable sources_not_empty;
|
||||||
|
std::unique_ptr<Thread> thread;
|
||||||
|
int current_index;
|
||||||
|
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||||
|
new_thread_work_sources TF_GUARDED_BY(mu);
|
||||||
|
|
||||||
|
uint64 current_version;
|
||||||
|
// Should only be accessed by one thread.
|
||||||
|
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||||
|
current_thread_work_sources;
|
||||||
|
|
||||||
|
int sub_thread_pool_id;
|
||||||
|
};
|
||||||
|
|
||||||
|
const int num_threads_;
|
||||||
|
const int num_blocking_threads_;
|
||||||
|
const int num_non_blocking_threads_;
|
||||||
|
Eigen::MaxSizeVector<ThreadData> thread_data_;
|
||||||
|
internal::RunHandlerEnvironment env_;
|
||||||
|
std::atomic<bool> cancelled_;
|
||||||
|
string name_;
|
||||||
|
Eigen::MaxSizeVector<mutex>* waiters_mu_;
|
||||||
|
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
|
||||||
|
|
||||||
|
bool use_sub_thread_pool_;
|
||||||
|
std::vector<int> num_threads_in_sub_thread_pool_;
|
||||||
|
|
||||||
|
// Threads in each sub thread pool will search tasks from the given
|
||||||
|
// start_request_percentage to end_request_percentage in a round robin
|
||||||
|
// fashion.
|
||||||
|
std::vector<double> sub_thread_pool_start_request_percentage_;
|
||||||
|
std::vector<double> sub_thread_pool_end_request_percentage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
} // end namespace tensorflow.
|
} // end namespace tensorflow.
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
|
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
|
||||||
|
@ -113,6 +113,476 @@ TEST(RunHandlerUtilTest, PrioritySchedulingTest) {
|
|||||||
EXPECT_EQ(sorted_active_list[3], 1);
|
EXPECT_EQ(sorted_active_list[3], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(RunHandlerThreadPool, EnqueueTask) {
|
||||||
|
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||||
|
waiters_mu.resize(2);
|
||||||
|
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||||
|
waiters.resize(2);
|
||||||
|
internal::RunHandlerThreadPool run_handler_thread_pool(
|
||||||
|
/*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0,
|
||||||
|
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||||
|
&waiters);
|
||||||
|
internal::ThreadWorkSource tws;
|
||||||
|
|
||||||
|
int result = 0;
|
||||||
|
std::function<void()> fn = [&result] { result = 1; };
|
||||||
|
std::function<void()> fn2 = [&result] { result = 2; };
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn);
|
||||||
|
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 1);
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn2);
|
||||||
|
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 2);
|
||||||
|
tws.PopBlockingTask().f->f();
|
||||||
|
EXPECT_EQ(result, 1);
|
||||||
|
tws.PopBlockingTask().f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn);
|
||||||
|
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 1);
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn2);
|
||||||
|
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 2);
|
||||||
|
tws.PopNonBlockingTask(0, true).f->f();
|
||||||
|
EXPECT_EQ(result, 1);
|
||||||
|
tws.PopNonBlockingTask(0, true).f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RunHandlerThreadPool, FindTask) {
|
||||||
|
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||||
|
waiters_mu.resize(2);
|
||||||
|
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||||
|
waiters.resize(2);
|
||||||
|
internal::RunHandlerThreadPool run_handler_thread_pool(
|
||||||
|
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
|
||||||
|
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||||
|
&waiters);
|
||||||
|
|
||||||
|
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(5);
|
||||||
|
thread_work_sources.resize(5);
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
thread_work_sources[i] = new internal::ThreadWorkSource();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// The thread should search the task following round robin fashion.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 3; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 3; });
|
||||||
|
|
||||||
|
const auto find_blocking_task_from_all_handlers =
|
||||||
|
[&](bool* task_from_blocking_queue, internal::Task* t) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||||
|
task_from_blocking_queue, &tws);
|
||||||
|
};
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 3);
|
||||||
|
|
||||||
|
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Task out of searching range cannot be found.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 3; });
|
||||||
|
|
||||||
|
const auto find_blocking_task_from_range =
|
||||||
|
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||||
|
int range_end) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
range_start, range_end,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||||
|
task_from_blocking_queue, &tws);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
|
||||||
|
EXPECT_EQ(t.f, nullptr);
|
||||||
|
|
||||||
|
// Clean up the queue.
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// The thread should search from start range if the currrent index is
|
||||||
|
// smaller.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 3; });
|
||||||
|
|
||||||
|
const auto find_blocking_task_from_range =
|
||||||
|
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||||
|
int range_end) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
range_start, range_end,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||||
|
task_from_blocking_queue, &tws);
|
||||||
|
};
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 3);
|
||||||
|
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// The thread should search within the range even if the current index
|
||||||
|
// is larger than searching_range_end;
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
|
||||||
|
const auto find_blocking_task_from_range =
|
||||||
|
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||||
|
int range_end) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
range_start, range_end,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||||
|
task_from_blocking_queue, &tws);
|
||||||
|
};
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
// Make the current index to be 3.
|
||||||
|
internal::Task t;
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
// Search in a smaller range.
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 3; });
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
// Clean up the queue.
|
||||||
|
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// We prefer blocking task for blocking threads.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/false,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
const auto blocking_thread_find_task_from_all_handler =
|
||||||
|
[&](bool* task_from_blocking_queue, internal::Task* t) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||||
|
task_from_blocking_queue, &tws);
|
||||||
|
};
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, true);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, false);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonblocking threads can only pick up non-blocking task.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/false,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
|
||||||
|
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
|
||||||
|
internal::Task* t,
|
||||||
|
bool is_blocking_thread) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
|
||||||
|
&tws);
|
||||||
|
};
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||||
|
/*is_blocking_thread=*/false);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, false);
|
||||||
|
t.f->f();
|
||||||
|
EXPECT_EQ(result, 2);
|
||||||
|
|
||||||
|
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||||
|
/*is_blocking_thread=*/false);
|
||||||
|
EXPECT_EQ(t.f, nullptr);
|
||||||
|
|
||||||
|
// Clean up the queue.
|
||||||
|
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||||
|
/*is_blocking_thread=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// There is a limit for max_blocking_inflight requests.
|
||||||
|
int result = -1;
|
||||||
|
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||||
|
/*is_blocking=*/true,
|
||||||
|
[&result] { result = 2; });
|
||||||
|
|
||||||
|
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
|
||||||
|
internal::Task* t,
|
||||||
|
bool is_blocking_thread) {
|
||||||
|
internal::ThreadWorkSource* tws;
|
||||||
|
*t = run_handler_thread_pool.FindTask(
|
||||||
|
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||||
|
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
|
||||||
|
&tws);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool task_from_blocking_queue;
|
||||||
|
internal::Task t;
|
||||||
|
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||||
|
/*is_blocking_thread=*/false);
|
||||||
|
EXPECT_EQ(task_from_blocking_queue, false);
|
||||||
|
EXPECT_EQ(t.f, nullptr);
|
||||||
|
|
||||||
|
// Clean up the queue.
|
||||||
|
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||||
|
/*is_blocking_thread=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
delete thread_work_sources[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RunHandlerThreadPool, RoundRobinExecution) {
|
||||||
|
// Set up environment for 1 sub thread pool.
|
||||||
|
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
|
||||||
|
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true);
|
||||||
|
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true);
|
||||||
|
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true);
|
||||||
|
|
||||||
|
Eigen::MaxSizeVector<mutex> waiters_mu(1);
|
||||||
|
waiters_mu.resize(1);
|
||||||
|
Eigen::MaxSizeVector<internal::Waiter> waiters(1);
|
||||||
|
waiters.resize(1);
|
||||||
|
internal::RunHandlerThreadPool* run_handler_thread_pool =
|
||||||
|
new internal::RunHandlerThreadPool(
|
||||||
|
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
|
||||||
|
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||||
|
&waiters);
|
||||||
|
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(3);
|
||||||
|
thread_work_sources.resize(3);
|
||||||
|
internal::ThreadWorkSource tws[3];
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]);
|
||||||
|
thread_work_sources[i] = &tws[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int result = 0;
|
||||||
|
mutex mu;
|
||||||
|
bool ok_to_execute = false;
|
||||||
|
bool ok_to_validate = false;
|
||||||
|
condition_variable function_start;
|
||||||
|
condition_variable function_end;
|
||||||
|
std::vector<std::function<void()>> fns;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
|
||||||
|
&ok_to_validate, i] {
|
||||||
|
mutex_lock l(mu);
|
||||||
|
while (!ok_to_execute) {
|
||||||
|
function_start.wait(l);
|
||||||
|
}
|
||||||
|
result = i;
|
||||||
|
ok_to_execute = false;
|
||||||
|
ok_to_validate = true;
|
||||||
|
function_end.notify_one();
|
||||||
|
});
|
||||||
|
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||||
|
fns[i]);
|
||||||
|
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||||
|
fns[i]);
|
||||||
|
}
|
||||||
|
run_handler_thread_pool->Start();
|
||||||
|
run_handler_thread_pool->SetThreadWorkSources(
|
||||||
|
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||||
|
|
||||||
|
// Validate the execution should be roundrobin.
|
||||||
|
mutex_lock l(mu);
|
||||||
|
for (int round = 0; round < 2; ++round) {
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
ok_to_execute = true;
|
||||||
|
function_start.notify_one();
|
||||||
|
while (!ok_to_validate) {
|
||||||
|
function_end.wait(l);
|
||||||
|
}
|
||||||
|
ok_to_validate = false;
|
||||||
|
EXPECT_EQ(result, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete run_handler_thread_pool;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RunHandlerThreadPool, MultipleSubThreadPool) {
|
||||||
|
// Set up environment for 2 sub thread pools.
|
||||||
|
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
|
||||||
|
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true);
|
||||||
|
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5",
|
||||||
|
true);
|
||||||
|
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1",
|
||||||
|
true);
|
||||||
|
|
||||||
|
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||||
|
waiters_mu.resize(2);
|
||||||
|
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||||
|
waiters.resize(2);
|
||||||
|
internal::RunHandlerThreadPool* run_handler_thread_pool =
|
||||||
|
new internal::RunHandlerThreadPool(
|
||||||
|
/*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0,
|
||||||
|
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||||
|
&waiters);
|
||||||
|
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(4);
|
||||||
|
thread_work_sources.resize(4);
|
||||||
|
internal::ThreadWorkSource tws[4];
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]);
|
||||||
|
thread_work_sources[i] = &tws[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int result = 0;
|
||||||
|
mutex mu;
|
||||||
|
bool ok_to_execute = false;
|
||||||
|
bool ok_to_validate = false;
|
||||||
|
condition_variable function_start;
|
||||||
|
condition_variable function_end;
|
||||||
|
|
||||||
|
std::vector<std::function<void()>> fns;
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
|
||||||
|
&ok_to_validate, i] {
|
||||||
|
mutex_lock l(mu);
|
||||||
|
while (!ok_to_execute) {
|
||||||
|
function_start.wait(l);
|
||||||
|
}
|
||||||
|
result = i;
|
||||||
|
ok_to_execute = false;
|
||||||
|
ok_to_validate = true;
|
||||||
|
function_end.notify_one();
|
||||||
|
});
|
||||||
|
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||||
|
fns[i]);
|
||||||
|
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||||
|
fns[i]);
|
||||||
|
}
|
||||||
|
run_handler_thread_pool->StartOneThreadForTesting();
|
||||||
|
run_handler_thread_pool->SetThreadWorkSources(
|
||||||
|
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||||
|
run_handler_thread_pool->SetThreadWorkSources(
|
||||||
|
/*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||||
|
|
||||||
|
// Pick task from the given sub thread pool requests in a round robin fashion.
|
||||||
|
mutex_lock l(mu);
|
||||||
|
for (int round = 0; round < 2; ++round) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
ok_to_execute = true;
|
||||||
|
function_start.notify_one();
|
||||||
|
while (!ok_to_validate) {
|
||||||
|
function_end.wait(l);
|
||||||
|
}
|
||||||
|
ok_to_validate = false;
|
||||||
|
EXPECT_EQ(result, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pick task from any task if there is no tasks from the requests in the sub
|
||||||
|
// thread pool.
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int round = 0; round < 2; ++round) {
|
||||||
|
ok_to_execute = true;
|
||||||
|
function_start.notify_one();
|
||||||
|
while (!ok_to_validate) {
|
||||||
|
function_end.wait(l);
|
||||||
|
}
|
||||||
|
ok_to_validate = false;
|
||||||
|
EXPECT_EQ(result, i + 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete run_handler_thread_pool;
|
||||||
|
}
|
||||||
|
|
||||||
SessionOptions DefaultSessionOptions() {
|
SessionOptions DefaultSessionOptions() {
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user