Refactor code for run handler thread pool, and add some unit tests.
PiperOrigin-RevId: 300835548 Change-Id: Ieb24ef54a327657af6452aaad893cb39fa02bd1c
This commit is contained in:
parent
8e80c097df
commit
f4cd9ee784
@ -41,28 +41,18 @@ namespace {
|
||||
static constexpr int32 kMaxConcurrentHandlers = 128;
|
||||
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
|
||||
|
||||
// TODO(azaks): Refactor with thread:ThreadPool
|
||||
class RunHandlerEnvironment {
|
||||
typedef Thread EnvThread;
|
||||
struct TaskImpl {
|
||||
std::function<void()> f;
|
||||
Context context;
|
||||
uint64 trace_id;
|
||||
};
|
||||
Env* const env_;
|
||||
const ThreadOptions thread_options_;
|
||||
const string name_;
|
||||
typedef typename internal::RunHandlerEnvironment::Task Task;
|
||||
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||
|
||||
public:
|
||||
struct Task {
|
||||
std::unique_ptr<TaskImpl> f;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
|
||||
const string& name)
|
||||
namespace internal {
|
||||
RunHandlerEnvironment::RunHandlerEnvironment(
|
||||
Env* env, const ThreadOptions& thread_options, const string& name)
|
||||
: env_(env), thread_options_(thread_options), name_(name) {}
|
||||
|
||||
EnvThread* CreateThread(std::function<void()> f) {
|
||||
RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
|
||||
std::function<void()> f) {
|
||||
return env_->StartThread(thread_options_, name_, [=]() {
|
||||
// Set the processor flag to flush denormals to zero.
|
||||
port::ScopedFlushDenormal flush;
|
||||
@ -75,7 +65,8 @@ class RunHandlerEnvironment {
|
||||
});
|
||||
}
|
||||
|
||||
Task CreateTask(std::function<void()> f) {
|
||||
RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
|
||||
std::function<void()> f) {
|
||||
uint64 id = 0;
|
||||
if (tracing::EventCollector::IsEnabled()) {
|
||||
id = tracing::GetUniqueArg();
|
||||
@ -90,30 +81,12 @@ class RunHandlerEnvironment {
|
||||
};
|
||||
}
|
||||
|
||||
void ExecuteTask(const Task& t) {
|
||||
void RunHandlerEnvironment::ExecuteTask(const Task& t) {
|
||||
WithContext wc(t.f->context);
|
||||
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
|
||||
t.f->trace_id);
|
||||
t.f->f();
|
||||
}
|
||||
};
|
||||
|
||||
typedef typename RunHandlerEnvironment::Task Task;
|
||||
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||
|
||||
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
||||
// queue them in LIFO order rather than the FIFO order used by a single
|
||||
// condition variable.
|
||||
struct Waiter {
|
||||
Waiter() {
|
||||
next = this;
|
||||
prev = this;
|
||||
}
|
||||
condition_variable cv;
|
||||
mutex mu;
|
||||
Waiter* next;
|
||||
Waiter* prev;
|
||||
};
|
||||
|
||||
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
|
||||
int max_sleep_micros) {
|
||||
@ -150,9 +123,7 @@ void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
|
||||
}
|
||||
}
|
||||
|
||||
class ThreadWorkSource {
|
||||
public:
|
||||
ThreadWorkSource()
|
||||
ThreadWorkSource::ThreadWorkSource()
|
||||
: non_blocking_work_sharding_factor_(
|
||||
static_cast<int32>(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
|
||||
@ -169,13 +140,13 @@ class ThreadWorkSource {
|
||||
}
|
||||
}
|
||||
|
||||
~ThreadWorkSource() {
|
||||
ThreadWorkSource::~ThreadWorkSource() {
|
||||
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
|
||||
delete non_blocking_work_queues_[i];
|
||||
}
|
||||
}
|
||||
|
||||
Task EnqueueTask(Task t, bool is_blocking) {
|
||||
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
|
||||
mutex* mu = nullptr;
|
||||
Queue* task_queue = nullptr;
|
||||
thread_local int64 closure_counter = 0;
|
||||
@ -196,8 +167,8 @@ class ThreadWorkSource {
|
||||
}
|
||||
|
||||
Waiter* w = nullptr;
|
||||
bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
|
||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
|
||||
bool use_sub_thread_pool =
|
||||
ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
|
||||
|
||||
Waiter* waiter_queue;
|
||||
mutex* waiter_queue_mu;
|
||||
@ -216,15 +187,14 @@ class ThreadWorkSource {
|
||||
waiter_queue = &queue_waiters_;
|
||||
waiter_queue_mu = &waiters_mu_;
|
||||
}
|
||||
|
||||
{
|
||||
mutex_lock l(*waiter_queue_mu);
|
||||
if (waiter_queue->next != waiter_queue) {
|
||||
// Remove waiter from the LIFO queue
|
||||
w = waiter_queue->next;
|
||||
|
||||
CHECK(w->prev != w);
|
||||
CHECK(w->next != w);
|
||||
CHECK(w->prev != w); // Crash OK.
|
||||
CHECK(w->next != w); // Crash OK.
|
||||
|
||||
w->next->prev = w->prev;
|
||||
w->prev->next = w->next;
|
||||
@ -246,9 +216,12 @@ class ThreadWorkSource {
|
||||
return t;
|
||||
}
|
||||
|
||||
Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
|
||||
Task ThreadWorkSource::PopBlockingTask() {
|
||||
return blocking_work_queue_.PopBack();
|
||||
}
|
||||
|
||||
Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
|
||||
Task ThreadWorkSource::PopNonBlockingTask(int start_index,
|
||||
bool search_from_all_queue) {
|
||||
Task t;
|
||||
unsigned sharding_factor = NonBlockingWorkShardingFactor();
|
||||
for (unsigned j = 0; j < sharding_factor; ++j) {
|
||||
@ -264,12 +237,12 @@ class ThreadWorkSource {
|
||||
return t;
|
||||
}
|
||||
|
||||
void WaitForWork(int max_sleep_micros) {
|
||||
void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
|
||||
thread_local Waiter waiter;
|
||||
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
|
||||
}
|
||||
|
||||
int TaskQueueSize(bool is_blocking) {
|
||||
int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
|
||||
if (is_blocking) {
|
||||
return blocking_work_queue_.Size();
|
||||
} else {
|
||||
@ -281,11 +254,13 @@ class ThreadWorkSource {
|
||||
}
|
||||
}
|
||||
|
||||
int64 GetTracemeId() { return traceme_id_.load(std::memory_order_relaxed); }
|
||||
int64 ThreadWorkSource::GetTracemeId() {
|
||||
return traceme_id_.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void SetTracemeId(int64 value) { traceme_id_ = value; }
|
||||
void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }
|
||||
|
||||
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
|
||||
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
|
||||
{
|
||||
tf_shared_lock lock(run_handler_waiter_mu_);
|
||||
// Most of the request won't change sub pool for recomputation.
|
||||
@ -305,29 +280,29 @@ class ThreadWorkSource {
|
||||
version_ = version;
|
||||
}
|
||||
|
||||
int64 GetInflightTaskCount(bool is_blocking) {
|
||||
int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
|
||||
std::atomic<int64>* counter =
|
||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||
return counter->load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void IncrementInflightTaskCount(bool is_blocking) {
|
||||
void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
|
||||
std::atomic<int64>* counter =
|
||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||
counter->fetch_add(1, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void DecrementInflightTaskCount(bool is_blocking) {
|
||||
void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
|
||||
std::atomic<int64>* counter =
|
||||
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
|
||||
counter->fetch_sub(1, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
unsigned NonBlockingWorkShardingFactor() {
|
||||
unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
|
||||
return non_blocking_work_sharding_factor_;
|
||||
}
|
||||
|
||||
std::string ToString() {
|
||||
std::string ThreadWorkSource::ToString() {
|
||||
return strings::StrCat("traceme_id = ", GetTracemeId(),
|
||||
", inter queue size = ", TaskQueueSize(true),
|
||||
", inter inflight = ", GetInflightTaskCount(true),
|
||||
@ -335,43 +310,9 @@ class ThreadWorkSource {
|
||||
", intra inflight = ", GetInflightTaskCount(false));
|
||||
}
|
||||
|
||||
private:
|
||||
struct NonBlockingQueue {
|
||||
mutex queue_op_mu;
|
||||
char pad[128];
|
||||
Queue queue;
|
||||
};
|
||||
|
||||
int32 non_blocking_work_sharding_factor_;
|
||||
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
|
||||
|
||||
std::atomic<int64> blocking_inflight_;
|
||||
std::atomic<int64> non_blocking_inflight_;
|
||||
|
||||
Queue blocking_work_queue_;
|
||||
mutex blocking_queue_op_mu_;
|
||||
char pad_[128];
|
||||
mutex waiters_mu_;
|
||||
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
|
||||
std::atomic<int64> traceme_id_;
|
||||
|
||||
mutex run_handler_waiter_mu_;
|
||||
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
};
|
||||
|
||||
class RunHandlerThreadPool {
|
||||
public:
|
||||
struct PerThread {
|
||||
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
|
||||
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
|
||||
int thread_id; // Worker thread index in pool.
|
||||
};
|
||||
|
||||
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
|
||||
Env* env, const ThreadOptions& thread_options,
|
||||
const string& name,
|
||||
RunHandlerThreadPool::RunHandlerThreadPool(
|
||||
int num_blocking_threads, int num_non_blocking_threads, Env* env,
|
||||
const ThreadOptions& thread_options, const string& name,
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters)
|
||||
: num_threads_(num_blocking_threads + num_non_blocking_threads),
|
||||
@ -386,8 +327,7 @@ class RunHandlerThreadPool {
|
||||
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
|
||||
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
|
||||
std::vector<int>(
|
||||
{num_blocking_threads / 2,
|
||||
std::vector<int>({num_blocking_threads / 2,
|
||||
num_blocking_threads - num_blocking_threads / 2}))),
|
||||
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
|
||||
@ -395,12 +335,13 @@ class RunHandlerThreadPool {
|
||||
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
|
||||
std::vector<double>({0.4, 1}))) {
|
||||
thread_data_.resize(num_threads_);
|
||||
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
|
||||
<< num_blocking_threads_ << " blocking threads and "
|
||||
<< num_non_blocking_threads_ << " non-blocking threads.";
|
||||
}
|
||||
|
||||
~RunHandlerThreadPool() {
|
||||
RunHandlerThreadPool::~RunHandlerThreadPool() {
|
||||
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
|
||||
|
||||
cancelled_ = true;
|
||||
@ -413,9 +354,8 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
void Start() {
|
||||
void RunHandlerThreadPool::Start() {
|
||||
cancelled_ = false;
|
||||
thread_data_.resize(num_threads_);
|
||||
int num_blocking_threads = num_blocking_threads_;
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
|
||||
@ -433,7 +373,15 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
|
||||
void RunHandlerThreadPool::StartOneThreadForTesting() {
|
||||
cancelled_ = false;
|
||||
thread_data_[0].sub_thread_pool_id = 0;
|
||||
thread_data_[0].thread.reset(
|
||||
env_.CreateThread([this]() { WorkerLoop(0, true); }));
|
||||
}
|
||||
|
||||
void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
|
||||
bool is_blocking,
|
||||
std::function<void()> fn) {
|
||||
Task t = env_.CreateTask(std::move(fn));
|
||||
t = tws->EnqueueTask(std::move(t), is_blocking);
|
||||
@ -444,16 +392,12 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
// Set work queues from which the thread 'tid' can steal its work.
|
||||
// The request with start_request_idx will be attempted first. Other requests
|
||||
// will be attempted in FIFO order based on their arrival time.
|
||||
|
||||
// TODO(donglin) Change the task steal order to be round-robin such that if
|
||||
// an attempt to steal task from request i failed, then attempt to steal task
|
||||
// from the next request in terms of the arrival time. This approach may
|
||||
// provide better performance due to less lock retention. The drawback is that
|
||||
// the profiler will be a bit harder to read.
|
||||
void SetThreadWorkSources(
|
||||
void RunHandlerThreadPool::SetThreadWorkSources(
|
||||
int tid, int start_request_idx, uint64 version,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
|
||||
mutex_lock l(thread_data_[tid].mu);
|
||||
@ -464,7 +408,6 @@ class RunHandlerThreadPool {
|
||||
return;
|
||||
}
|
||||
thread_data_[tid].new_thread_work_sources->resize(0);
|
||||
|
||||
if (use_sub_thread_pool_) {
|
||||
for (int i = 0; i < thread_work_sources.size(); ++i) {
|
||||
thread_data_[tid].new_thread_work_sources->emplace_back(
|
||||
@ -481,8 +424,7 @@ class RunHandlerThreadPool {
|
||||
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
|
||||
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
|
||||
// 4... for the other half of the threads.
|
||||
int num_shards =
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
||||
int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
|
||||
int token = tid % num_shards;
|
||||
for (int i = 0; i < num_shards; ++i) {
|
||||
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
|
||||
@ -497,15 +439,14 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
PerThread* GetPerThread() {
|
||||
thread_local PerThread per_thread_;
|
||||
PerThread* pt = &per_thread_;
|
||||
RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
|
||||
thread_local RunHandlerThreadPool::PerThread per_thread_;
|
||||
RunHandlerThreadPool::PerThread* pt = &per_thread_;
|
||||
return pt;
|
||||
}
|
||||
|
||||
int CurrentThreadId() const {
|
||||
const PerThread* pt =
|
||||
const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
|
||||
int RunHandlerThreadPool::CurrentThreadId() const {
|
||||
const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
|
||||
if (pt->pool == this) {
|
||||
return pt->thread_id;
|
||||
} else {
|
||||
@ -513,79 +454,28 @@ class RunHandlerThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
int NumThreads() const { return num_threads_; }
|
||||
int RunHandlerThreadPool::NumThreads() const { return num_threads_; }
|
||||
|
||||
int NumBlockingThreads() const { return num_blocking_threads_; }
|
||||
int RunHandlerThreadPool::NumBlockingThreads() const {
|
||||
return num_blocking_threads_;
|
||||
}
|
||||
|
||||
int NumNonBlockingThreads() const { return num_non_blocking_threads_; }
|
||||
int RunHandlerThreadPool::NumNonBlockingThreads() const {
|
||||
return num_non_blocking_threads_;
|
||||
}
|
||||
|
||||
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
|
||||
|
||||
// Search tasks from Requets range searching_range_start to
|
||||
// searching_range_end. If there is no tasks in the search range and
|
||||
// may_steal_blocking_work is true, then search from all requests.
|
||||
Task FindTask(
|
||||
int searching_range_start, int searching_range_end, int thread_id,
|
||||
int sub_thread_pool_id, int max_blocking_inflight,
|
||||
bool may_steal_blocking_work,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
|
||||
bool* task_from_blocking_queue, ThreadWorkSource** tws);
|
||||
|
||||
void WaitForWork(bool is_blocking, int thread_id,
|
||||
int32 max_blocking_inflight);
|
||||
|
||||
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
|
||||
|
||||
private:
|
||||
struct ThreadData {
|
||||
ThreadData()
|
||||
RunHandlerThreadPool::ThreadData::ThreadData()
|
||||
: new_version(0),
|
||||
current_index(0),
|
||||
new_thread_work_sources(new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
||||
static_cast<int32>(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
new_thread_work_sources(
|
||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
kMaxConcurrentHandlers)))),
|
||||
current_version(0),
|
||||
current_thread_work_sources(
|
||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
||||
static_cast<int32>(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
kMaxConcurrentHandlers)))) {}
|
||||
mutex mu;
|
||||
uint64 new_version;
|
||||
condition_variable sources_not_empty;
|
||||
std::unique_ptr<Thread> thread;
|
||||
int current_index;
|
||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||
new_thread_work_sources TF_GUARDED_BY(mu);
|
||||
|
||||
uint64 current_version;
|
||||
// Should only be accessed by one thread.
|
||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||
current_thread_work_sources;
|
||||
|
||||
int sub_thread_pool_id;
|
||||
};
|
||||
|
||||
const int num_threads_;
|
||||
const int num_blocking_threads_;
|
||||
const int num_non_blocking_threads_;
|
||||
Eigen::MaxSizeVector<ThreadData> thread_data_;
|
||||
RunHandlerEnvironment env_;
|
||||
std::atomic<bool> cancelled_;
|
||||
string name_;
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu_;
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
|
||||
|
||||
bool use_sub_thread_pool_;
|
||||
std::vector<int> num_threads_in_sub_thread_pool_;
|
||||
|
||||
// Threads in each sub thread pool will search tasks from the given
|
||||
// start_request_percentage to end_request_percentage in a round robin
|
||||
// fashion.
|
||||
std::vector<double> sub_thread_pool_start_request_percentage_;
|
||||
std::vector<double> sub_thread_pool_end_request_percentage_;
|
||||
};
|
||||
|
||||
Task RunHandlerThreadPool::FindTask(
|
||||
int searching_range_start, int searching_range_end, int thread_id,
|
||||
@ -597,10 +487,9 @@ Task RunHandlerThreadPool::FindTask(
|
||||
int current_index = thread_data_[thread_id].current_index;
|
||||
*task_from_blocking_queue = false;
|
||||
|
||||
// TODO(chaox): Change the search algorithm from round robin to random
|
||||
// walk.
|
||||
for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
|
||||
if (current_index >= searching_range_end) {
|
||||
if (current_index >= searching_range_end ||
|
||||
current_index < searching_range_start) {
|
||||
current_index = searching_range_start;
|
||||
}
|
||||
*tws = thread_work_sources[current_index];
|
||||
@ -821,7 +710,7 @@ void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
|
||||
tws->WaitForWork(kMaxSleepMicros);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
|
||||
// Contains the concrete implementation of the RunHandler.
|
||||
// Externally visible RunHandler class simply forwards the work to this one.
|
||||
@ -847,7 +736,7 @@ class RunHandler::Impl {
|
||||
|
||||
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
|
||||
|
||||
ThreadWorkSource* tws() { return &tws_; }
|
||||
internal::ThreadWorkSource* tws() { return &tws_; }
|
||||
|
||||
int64 priority() { return options_.priority(); }
|
||||
|
||||
@ -869,7 +758,7 @@ class RunHandler::Impl {
|
||||
uint64 start_time_us_;
|
||||
int64 step_id_;
|
||||
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
|
||||
ThreadWorkSource tws_;
|
||||
internal::ThreadWorkSource tws_;
|
||||
RunOptions::Experimental::RunHandlerPoolOptions options_;
|
||||
};
|
||||
|
||||
@ -885,7 +774,7 @@ class RunHandlerPool::Impl {
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||
queue_waiters_(
|
||||
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
|
||||
run_handler_thread_pool_(new RunHandlerThreadPool(
|
||||
run_handler_thread_pool_(new internal::RunHandlerThreadPool(
|
||||
num_inter_op_threads, num_intra_op_threads, Env::Default(),
|
||||
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
|
||||
&queue_waiters_)),
|
||||
@ -924,7 +813,7 @@ class RunHandlerPool::Impl {
|
||||
run_handler_thread_pool_.reset();
|
||||
}
|
||||
|
||||
RunHandlerThreadPool* run_handler_thread_pool() {
|
||||
internal::RunHandlerThreadPool* run_handler_thread_pool() {
|
||||
return run_handler_thread_pool_.get();
|
||||
}
|
||||
|
||||
@ -936,10 +825,11 @@ class RunHandlerPool::Impl {
|
||||
int64 step_id, int64 timeout_in_ms,
|
||||
const RunOptions::Experimental::RunHandlerPoolOptions& options)
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
thread_local std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||
thread_local std::unique_ptr<
|
||||
Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
|
||||
thread_work_sources =
|
||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>(
|
||||
new Eigen::MaxSizeVector<ThreadWorkSource*>(
|
||||
std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
|
||||
new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
|
||||
static_cast<int32>(ParamFromEnvWithDefault(
|
||||
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
|
||||
kMaxConcurrentHandlers))));
|
||||
@ -1035,7 +925,8 @@ class RunHandlerPool::Impl {
|
||||
private:
|
||||
void RecomputePoolStats(
|
||||
int num_active_requests, uint64 version,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
|
||||
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
|
||||
thread_work_sources);
|
||||
|
||||
void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
@ -1046,9 +937,9 @@ class RunHandlerPool::Impl {
|
||||
const int max_handlers_;
|
||||
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu_;
|
||||
Eigen::MaxSizeVector<Waiter> queue_waiters_;
|
||||
Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
|
||||
|
||||
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
|
||||
std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
|
||||
// Thread compatible part used only by lock under RunHandlerPool.
|
||||
// Handlers are sorted by start time.
|
||||
// TODO(azaks): sort by the remaining latency budget.
|
||||
@ -1070,7 +961,8 @@ class RunHandlerPool::Impl {
|
||||
|
||||
void RunHandlerPool::Impl::RecomputePoolStats(
|
||||
int num_active_requests, uint64 version,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
|
||||
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
|
||||
thread_work_sources) {
|
||||
if (num_active_requests == 0) return;
|
||||
|
||||
int sub_thread_pool_id = 0;
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/platform/context.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
@ -106,6 +107,208 @@ class RunHandler {
|
||||
Impl* impl_; // NOT OWNED.
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
// TODO(azaks): Refactor with thread:ThreadPool
|
||||
class RunHandlerEnvironment {
|
||||
typedef Thread EnvThread;
|
||||
struct TaskImpl {
|
||||
std::function<void()> f;
|
||||
Context context;
|
||||
uint64 trace_id;
|
||||
};
|
||||
Env* const env_;
|
||||
const ThreadOptions thread_options_;
|
||||
const string name_;
|
||||
|
||||
public:
|
||||
struct Task {
|
||||
std::unique_ptr<TaskImpl> f;
|
||||
};
|
||||
|
||||
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
|
||||
const string& name);
|
||||
|
||||
EnvThread* CreateThread(std::function<void()> f);
|
||||
|
||||
Task CreateTask(std::function<void()> f);
|
||||
|
||||
void ExecuteTask(const Task& t);
|
||||
};
|
||||
|
||||
typedef typename RunHandlerEnvironment::Task Task;
|
||||
typedef Eigen::RunQueue<Task, 1024> Queue;
|
||||
|
||||
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
|
||||
// queue them in LIFO order rather than the FIFO order used by a single
|
||||
// condition variable.
|
||||
struct Waiter {
|
||||
Waiter() {
|
||||
next = this;
|
||||
prev = this;
|
||||
}
|
||||
condition_variable cv;
|
||||
mutex mu;
|
||||
Waiter* next;
|
||||
Waiter* prev;
|
||||
};
|
||||
|
||||
class ThreadWorkSource {
|
||||
public:
|
||||
ThreadWorkSource();
|
||||
|
||||
~ThreadWorkSource();
|
||||
|
||||
Task EnqueueTask(Task t, bool is_blocking);
|
||||
|
||||
Task PopBlockingTask();
|
||||
|
||||
Task PopNonBlockingTask(int start_index, bool search_from_all_queue);
|
||||
|
||||
void WaitForWork(int max_sleep_micros);
|
||||
|
||||
int TaskQueueSize(bool is_blocking);
|
||||
|
||||
int64 GetTracemeId();
|
||||
|
||||
void SetTracemeId(int64 value);
|
||||
|
||||
void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);
|
||||
|
||||
int64 GetInflightTaskCount(bool is_blocking);
|
||||
|
||||
void IncrementInflightTaskCount(bool is_blocking);
|
||||
|
||||
void DecrementInflightTaskCount(bool is_blocking);
|
||||
|
||||
unsigned NonBlockingWorkShardingFactor();
|
||||
|
||||
std::string ToString();
|
||||
|
||||
private:
|
||||
struct NonBlockingQueue {
|
||||
mutex queue_op_mu;
|
||||
char pad[128];
|
||||
Queue queue;
|
||||
};
|
||||
|
||||
int32 non_blocking_work_sharding_factor_;
|
||||
Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;
|
||||
|
||||
std::atomic<int64> blocking_inflight_;
|
||||
std::atomic<int64> non_blocking_inflight_;
|
||||
|
||||
Queue blocking_work_queue_;
|
||||
mutex blocking_queue_op_mu_;
|
||||
char pad_[128];
|
||||
mutex waiters_mu_;
|
||||
Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
|
||||
std::atomic<int64> traceme_id_;
|
||||
|
||||
mutex run_handler_waiter_mu_;
|
||||
uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
|
||||
};
|
||||
|
||||
class RunHandlerThreadPool {
|
||||
public:
|
||||
struct PerThread {
|
||||
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
|
||||
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
|
||||
int thread_id; // Worker thread index in pool.
|
||||
};
|
||||
|
||||
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
|
||||
Env* env, const ThreadOptions& thread_options,
|
||||
const string& name,
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu,
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters);
|
||||
|
||||
~RunHandlerThreadPool();
|
||||
|
||||
void Start();
|
||||
|
||||
void StartOneThreadForTesting();
|
||||
|
||||
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
|
||||
std::function<void()> fn);
|
||||
|
||||
// Set work queues from which the thread 'tid' can steal its work.
|
||||
// The request with start_request_idx will be attempted first. Other requests
|
||||
// will be attempted in FIFO order based on their arrival time.
|
||||
void SetThreadWorkSources(
|
||||
int tid, int start_request_idx, uint64 version,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);
|
||||
|
||||
PerThread* GetPerThread();
|
||||
|
||||
int CurrentThreadId() const;
|
||||
|
||||
int NumThreads() const;
|
||||
|
||||
int NumBlockingThreads() const;
|
||||
|
||||
int NumNonBlockingThreads() const;
|
||||
|
||||
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
|
||||
|
||||
// Search tasks from Requets range searching_range_start to
|
||||
// searching_range_end. If there is no tasks in the search range and
|
||||
// may_steal_blocking_work is true, then search from all requests.
|
||||
Task FindTask(
|
||||
int searching_range_start, int searching_range_end, int thread_id,
|
||||
int sub_thread_pool_id, int max_blocking_inflight,
|
||||
bool may_steal_blocking_work,
|
||||
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
|
||||
bool* task_from_blocking_queue, ThreadWorkSource** tws);
|
||||
|
||||
void WaitForWork(bool is_blocking, int thread_id,
|
||||
int32 max_blocking_inflight);
|
||||
|
||||
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
|
||||
|
||||
private:
|
||||
struct ThreadData {
|
||||
ThreadData();
|
||||
mutex mu;
|
||||
uint64 new_version;
|
||||
condition_variable sources_not_empty;
|
||||
std::unique_ptr<Thread> thread;
|
||||
int current_index;
|
||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||
new_thread_work_sources TF_GUARDED_BY(mu);
|
||||
|
||||
uint64 current_version;
|
||||
// Should only be accessed by one thread.
|
||||
std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
|
||||
current_thread_work_sources;
|
||||
|
||||
int sub_thread_pool_id;
|
||||
};
|
||||
|
||||
const int num_threads_;
|
||||
const int num_blocking_threads_;
|
||||
const int num_non_blocking_threads_;
|
||||
Eigen::MaxSizeVector<ThreadData> thread_data_;
|
||||
internal::RunHandlerEnvironment env_;
|
||||
std::atomic<bool> cancelled_;
|
||||
string name_;
|
||||
Eigen::MaxSizeVector<mutex>* waiters_mu_;
|
||||
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
|
||||
|
||||
bool use_sub_thread_pool_;
|
||||
std::vector<int> num_threads_in_sub_thread_pool_;
|
||||
|
||||
// Threads in each sub thread pool will search tasks from the given
|
||||
// start_request_percentage to end_request_percentage in a round robin
|
||||
// fashion.
|
||||
std::vector<double> sub_thread_pool_start_request_percentage_;
|
||||
std::vector<double> sub_thread_pool_end_request_percentage_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
} // end namespace tensorflow.
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
|
||||
|
@ -113,6 +113,476 @@ TEST(RunHandlerUtilTest, PrioritySchedulingTest) {
|
||||
EXPECT_EQ(sorted_active_list[3], 1);
|
||||
}
|
||||
|
||||
TEST(RunHandlerThreadPool, EnqueueTask) {
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||
waiters_mu.resize(2);
|
||||
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||
waiters.resize(2);
|
||||
internal::RunHandlerThreadPool run_handler_thread_pool(
|
||||
/*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0,
|
||||
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||
&waiters);
|
||||
internal::ThreadWorkSource tws;
|
||||
|
||||
int result = 0;
|
||||
std::function<void()> fn = [&result] { result = 1; };
|
||||
std::function<void()> fn2 = [&result] { result = 2; };
|
||||
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn);
|
||||
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 1);
|
||||
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn2);
|
||||
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 2);
|
||||
tws.PopBlockingTask().f->f();
|
||||
EXPECT_EQ(result, 1);
|
||||
tws.PopBlockingTask().f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn);
|
||||
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 1);
|
||||
run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn2);
|
||||
EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 2);
|
||||
tws.PopNonBlockingTask(0, true).f->f();
|
||||
EXPECT_EQ(result, 1);
|
||||
tws.PopNonBlockingTask(0, true).f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
}
|
||||
|
||||
TEST(RunHandlerThreadPool, FindTask) {
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||
waiters_mu.resize(2);
|
||||
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||
waiters.resize(2);
|
||||
internal::RunHandlerThreadPool run_handler_thread_pool(
|
||||
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
|
||||
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||
&waiters);
|
||||
|
||||
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(5);
|
||||
thread_work_sources.resize(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
thread_work_sources[i] = new internal::ThreadWorkSource();
|
||||
}
|
||||
|
||||
{
|
||||
// The thread should search the task following round robin fashion.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 3; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 3; });
|
||||
|
||||
const auto find_blocking_task_from_all_handlers =
|
||||
[&](bool* task_from_blocking_queue, internal::Task* t) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||
task_from_blocking_queue, &tws);
|
||||
};
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 3);
|
||||
|
||||
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 3);
|
||||
}
|
||||
|
||||
{
|
||||
// Task out of searching range cannot be found.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 3; });
|
||||
|
||||
const auto find_blocking_task_from_range =
|
||||
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||
int range_end) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
range_start, range_end,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||
task_from_blocking_queue, &tws);
|
||||
};
|
||||
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
|
||||
EXPECT_EQ(t.f, nullptr);
|
||||
|
||||
// Clean up the queue.
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||
}
|
||||
|
||||
{
|
||||
// The thread should search from start range if the currrent index is
|
||||
// smaller.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 3; });
|
||||
|
||||
const auto find_blocking_task_from_range =
|
||||
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||
int range_end) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
range_start, range_end,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||
task_from_blocking_queue, &tws);
|
||||
};
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 3);
|
||||
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
}
|
||||
|
||||
{
|
||||
// The thread should search within the range even if the current index
|
||||
// is larger than searching_range_end;
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
|
||||
const auto find_blocking_task_from_range =
|
||||
[&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
|
||||
int range_end) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
range_start, range_end,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||
task_from_blocking_queue, &tws);
|
||||
};
|
||||
bool task_from_blocking_queue;
|
||||
// Make the current index to be 3.
|
||||
internal::Task t;
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
// Search in a smaller range.
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 3; });
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
// Clean up the queue.
|
||||
find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 3);
|
||||
}
|
||||
|
||||
{
|
||||
// We prefer blocking task for blocking threads.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/false,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
const auto blocking_thread_find_task_from_all_handler =
|
||||
[&](bool* task_from_blocking_queue, internal::Task* t) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
/*may_steal_blocking_work=*/true, thread_work_sources,
|
||||
task_from_blocking_queue, &tws);
|
||||
};
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, true);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
|
||||
EXPECT_EQ(task_from_blocking_queue, false);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonblocking threads can only pick up non-blocking task.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/false,
|
||||
[&result] { result = 2; });
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
|
||||
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
|
||||
internal::Task* t,
|
||||
bool is_blocking_thread) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
|
||||
&tws);
|
||||
};
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||
/*is_blocking_thread=*/false);
|
||||
EXPECT_EQ(task_from_blocking_queue, false);
|
||||
t.f->f();
|
||||
EXPECT_EQ(result, 2);
|
||||
|
||||
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||
/*is_blocking_thread=*/false);
|
||||
EXPECT_EQ(t.f, nullptr);
|
||||
|
||||
// Clean up the queue.
|
||||
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||
/*is_blocking_thread=*/true);
|
||||
}
|
||||
|
||||
{
|
||||
// There is a limit for max_blocking_inflight requests.
|
||||
int result = -1;
|
||||
run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
|
||||
/*is_blocking=*/true,
|
||||
[&result] { result = 2; });
|
||||
|
||||
const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
|
||||
internal::Task* t,
|
||||
bool is_blocking_thread) {
|
||||
internal::ThreadWorkSource* tws;
|
||||
*t = run_handler_thread_pool.FindTask(
|
||||
/*searching_range_start=*/0, /*searching_range_end=*/5,
|
||||
/*thread_id=*/0,
|
||||
/*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
|
||||
is_blocking_thread, thread_work_sources, task_from_blocking_queue,
|
||||
&tws);
|
||||
};
|
||||
|
||||
bool task_from_blocking_queue;
|
||||
internal::Task t;
|
||||
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||
/*is_blocking_thread=*/false);
|
||||
EXPECT_EQ(task_from_blocking_queue, false);
|
||||
EXPECT_EQ(t.f, nullptr);
|
||||
|
||||
// Clean up the queue.
|
||||
find_task_from_all_handler(&task_from_blocking_queue, &t,
|
||||
/*is_blocking_thread=*/true);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
delete thread_work_sources[i];
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RunHandlerThreadPool, RoundRobinExecution) {
|
||||
// Set up environment for 1 sub thread pool.
|
||||
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
|
||||
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true);
|
||||
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true);
|
||||
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true);
|
||||
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu(1);
|
||||
waiters_mu.resize(1);
|
||||
Eigen::MaxSizeVector<internal::Waiter> waiters(1);
|
||||
waiters.resize(1);
|
||||
internal::RunHandlerThreadPool* run_handler_thread_pool =
|
||||
new internal::RunHandlerThreadPool(
|
||||
/*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
|
||||
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||
&waiters);
|
||||
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(3);
|
||||
thread_work_sources.resize(3);
|
||||
internal::ThreadWorkSource tws[3];
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]);
|
||||
thread_work_sources[i] = &tws[i];
|
||||
}
|
||||
|
||||
int result = 0;
|
||||
mutex mu;
|
||||
bool ok_to_execute = false;
|
||||
bool ok_to_validate = false;
|
||||
condition_variable function_start;
|
||||
condition_variable function_end;
|
||||
std::vector<std::function<void()>> fns;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
|
||||
&ok_to_validate, i] {
|
||||
mutex_lock l(mu);
|
||||
while (!ok_to_execute) {
|
||||
function_start.wait(l);
|
||||
}
|
||||
result = i;
|
||||
ok_to_execute = false;
|
||||
ok_to_validate = true;
|
||||
function_end.notify_one();
|
||||
});
|
||||
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||
fns[i]);
|
||||
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||
fns[i]);
|
||||
}
|
||||
run_handler_thread_pool->Start();
|
||||
run_handler_thread_pool->SetThreadWorkSources(
|
||||
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||
|
||||
// Validate the execution should be roundrobin.
|
||||
mutex_lock l(mu);
|
||||
for (int round = 0; round < 2; ++round) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
ok_to_execute = true;
|
||||
function_start.notify_one();
|
||||
while (!ok_to_validate) {
|
||||
function_end.wait(l);
|
||||
}
|
||||
ok_to_validate = false;
|
||||
EXPECT_EQ(result, i);
|
||||
}
|
||||
}
|
||||
|
||||
delete run_handler_thread_pool;
|
||||
}
|
||||
|
||||
TEST(RunHandlerThreadPool, MultipleSubThreadPool) {
|
||||
// Set up environment for 2 sub thread pools.
|
||||
setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
|
||||
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true);
|
||||
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5",
|
||||
true);
|
||||
setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1",
|
||||
true);
|
||||
|
||||
Eigen::MaxSizeVector<mutex> waiters_mu(2);
|
||||
waiters_mu.resize(2);
|
||||
Eigen::MaxSizeVector<internal::Waiter> waiters(2);
|
||||
waiters.resize(2);
|
||||
internal::RunHandlerThreadPool* run_handler_thread_pool =
|
||||
new internal::RunHandlerThreadPool(
|
||||
/*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0,
|
||||
Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
|
||||
&waiters);
|
||||
Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(4);
|
||||
thread_work_sources.resize(4);
|
||||
internal::ThreadWorkSource tws[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]);
|
||||
thread_work_sources[i] = &tws[i];
|
||||
}
|
||||
|
||||
int result = 0;
|
||||
mutex mu;
|
||||
bool ok_to_execute = false;
|
||||
bool ok_to_validate = false;
|
||||
condition_variable function_start;
|
||||
condition_variable function_end;
|
||||
|
||||
std::vector<std::function<void()>> fns;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
|
||||
&ok_to_validate, i] {
|
||||
mutex_lock l(mu);
|
||||
while (!ok_to_execute) {
|
||||
function_start.wait(l);
|
||||
}
|
||||
result = i;
|
||||
ok_to_execute = false;
|
||||
ok_to_validate = true;
|
||||
function_end.notify_one();
|
||||
});
|
||||
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||
fns[i]);
|
||||
run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
|
||||
fns[i]);
|
||||
}
|
||||
run_handler_thread_pool->StartOneThreadForTesting();
|
||||
run_handler_thread_pool->SetThreadWorkSources(
|
||||
/*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||
run_handler_thread_pool->SetThreadWorkSources(
|
||||
/*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
|
||||
|
||||
// Pick task from the given sub thread pool requests in a round robin fashion.
|
||||
mutex_lock l(mu);
|
||||
for (int round = 0; round < 2; ++round) {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
ok_to_execute = true;
|
||||
function_start.notify_one();
|
||||
while (!ok_to_validate) {
|
||||
function_end.wait(l);
|
||||
}
|
||||
ok_to_validate = false;
|
||||
EXPECT_EQ(result, i);
|
||||
}
|
||||
}
|
||||
|
||||
// Pick task from any task if there is no tasks from the requests in the sub
|
||||
// thread pool.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int round = 0; round < 2; ++round) {
|
||||
ok_to_execute = true;
|
||||
function_start.notify_one();
|
||||
while (!ok_to_validate) {
|
||||
function_end.wait(l);
|
||||
}
|
||||
ok_to_validate = false;
|
||||
EXPECT_EQ(result, i + 2);
|
||||
}
|
||||
}
|
||||
|
||||
delete run_handler_thread_pool;
|
||||
}
|
||||
|
||||
SessionOptions DefaultSessionOptions() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
|
Loading…
x
Reference in New Issue
Block a user