diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc index b912554b9d4..c0ab50fe4e2 100644 --- a/tensorflow/core/framework/run_handler.cc +++ b/tensorflow/core/framework/run_handler.cc @@ -41,79 +41,52 @@ namespace { static constexpr int32 kMaxConcurrentHandlers = 128; // LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc) -// TODO(azaks): Refactor with thread:ThreadPool -class RunHandlerEnvironment { - typedef Thread EnvThread; - struct TaskImpl { - std::function f; - Context context; - uint64 trace_id; - }; - Env* const env_; - const ThreadOptions thread_options_; - const string name_; - - public: - struct Task { - std::unique_ptr f; - }; - - RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options, - const string& name) - : env_(env), thread_options_(thread_options), name_(name) {} - - EnvThread* CreateThread(std::function f) { - return env_->StartThread(thread_options_, name_, [=]() { - // Set the processor flag to flush denormals to zero. - port::ScopedFlushDenormal flush; - // Set the processor rounding mode to ROUND TO NEAREST. - port::ScopedSetRound round(FE_TONEAREST); - if (thread_options_.numa_node != port::kNUMANoAffinity) { - port::NUMASetThreadNodeAffinity(thread_options_.numa_node); - } - f(); - }); - } - - Task CreateTask(std::function f) { - uint64 id = 0; - if (tracing::EventCollector::IsEnabled()) { - id = tracing::GetUniqueArg(); - tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); - } - return Task{ - std::unique_ptr(new TaskImpl{ - std::move(f), - Context(ContextKind::kThread), - id, - }), - }; - } - - void ExecuteTask(const Task& t) { - WithContext wc(t.f->context); - tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, - t.f->trace_id); - t.f->f(); - } -}; - -typedef typename RunHandlerEnvironment::Task Task; +typedef typename internal::RunHandlerEnvironment::Task Task; typedef Eigen::RunQueue Queue; -// To reduce cache misses, we use a doubly-linked list of Waiter structs and -// queue them in LIFO order rather than the FIFO order used by a single -// condition variable. -struct Waiter { - Waiter() { - next = this; - prev = this; +} // namespace + +namespace internal { +RunHandlerEnvironment::RunHandlerEnvironment( + Env* env, const ThreadOptions& thread_options, const string& name) + : env_(env), thread_options_(thread_options), name_(name) {} + +RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread( + std::function f) { + return env_->StartThread(thread_options_, name_, [=]() { + // Set the processor flag to flush denormals to zero. + port::ScopedFlushDenormal flush; + // Set the processor rounding mode to ROUND TO NEAREST. 
+ port::ScopedSetRound round(FE_TONEAREST); + if (thread_options_.numa_node != port::kNUMANoAffinity) { + port::NUMASetThreadNodeAffinity(thread_options_.numa_node); + } + f(); + }); +} + +RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask( + std::function f) { + uint64 id = 0; + if (tracing::EventCollector::IsEnabled()) { + id = tracing::GetUniqueArg(); + tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); } - condition_variable cv; - mutex mu; - Waiter* next; - Waiter* prev; -}; + return Task{ + std::unique_ptr(new TaskImpl{ + std::move(f), + Context(ContextKind::kThread), + id, + }), + }; +} + +void RunHandlerEnvironment::ExecuteTask(const Task& t) { + WithContext wc(t.f->context); + tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, + t.f->trace_id); + t.f->f(); +} void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex, int max_sleep_micros) { @@ -150,442 +123,359 @@ void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex, } } -class ThreadWorkSource { - public: - ThreadWorkSource() - : non_blocking_work_sharding_factor_( - static_cast(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))), - non_blocking_work_queues_(non_blocking_work_sharding_factor_), - blocking_inflight_(0), - non_blocking_inflight_(0), - traceme_id_(0), - version_(0), - sub_thread_pool_waiter_(nullptr) { - queue_waiters_.next = &queue_waiters_; - queue_waiters_.prev = &queue_waiters_; - for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) { - non_blocking_work_queues_.emplace_back(new NonBlockingQueue()); +ThreadWorkSource::ThreadWorkSource() + : non_blocking_work_sharding_factor_( + static_cast(ParamFromEnvWithDefault( + "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))), + non_blocking_work_queues_(non_blocking_work_sharding_factor_), + blocking_inflight_(0), + non_blocking_inflight_(0), + traceme_id_(0), + version_(0), + sub_thread_pool_waiter_(nullptr) { + queue_waiters_.next = &queue_waiters_; + queue_waiters_.prev = &queue_waiters_; + for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) { + non_blocking_work_queues_.emplace_back(new NonBlockingQueue()); + } +} + +ThreadWorkSource::~ThreadWorkSource() { + for (int i = 0; i < non_blocking_work_queues_.size(); ++i) { + delete non_blocking_work_queues_[i]; + } +} + +Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) { + mutex* mu = nullptr; + Queue* task_queue = nullptr; + thread_local int64 closure_counter = 0; + + if (!is_blocking) { + int queue_index = ++closure_counter % non_blocking_work_sharding_factor_; + task_queue = &(non_blocking_work_queues_[queue_index]->queue); + mu = &non_blocking_work_queues_[queue_index]->queue_op_mu; + } else { + task_queue = &blocking_work_queue_; + mu = &blocking_queue_op_mu_; + } + + { + mutex_lock l(*mu); + // For a given queue, only one thread can call PushFront. + t = task_queue->PushFront(std::move(t)); + } + + Waiter* w = nullptr; + bool use_sub_thread_pool = + ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false); + + Waiter* waiter_queue; + mutex* waiter_queue_mu; + if (use_sub_thread_pool) { + // When we use multiple sub thread pools, free threads wait on sub + // thread pool waiting queues. Wake up threads from sub thread waiting + // queues. + // The waiting queues are defined at RunHandlerPool. + // Get the waiter_queue and corresponding mutex. Note, the thread work + // source may change afterwards if a new request comes or an old request + // finishes. 
+ tf_shared_lock lock(run_handler_waiter_mu_); + waiter_queue = sub_thread_pool_waiter_; + waiter_queue_mu = sub_thread_pool_waiter_mu_; + } else { + waiter_queue = &queue_waiters_; + waiter_queue_mu = &waiters_mu_; + } + { + mutex_lock l(*waiter_queue_mu); + if (waiter_queue->next != waiter_queue) { + // Remove waiter from the LIFO queue + w = waiter_queue->next; + + CHECK(w->prev != w); // Crash OK. + CHECK(w->next != w); // Crash OK. + + w->next->prev = w->prev; + w->prev->next = w->next; + + // Use `w->next == &w` to indicate that the waiter has been removed + // from the queue. + w->next = w; + w->prev = w; + } + } + if (w != nullptr) { + // We call notify_one() without any locks, so we can miss notifications. + // The wake up logic is best effort and a thread will wake in short + // period of time in case a notification is missed. + w->cv.notify_one(); + } + VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from " + << traceme_id_.load(std::memory_order_relaxed); + return t; +} + +Task ThreadWorkSource::PopBlockingTask() { + return blocking_work_queue_.PopBack(); +} + +Task ThreadWorkSource::PopNonBlockingTask(int start_index, + bool search_from_all_queue) { + Task t; + unsigned sharding_factor = NonBlockingWorkShardingFactor(); + for (unsigned j = 0; j < sharding_factor; ++j) { + t = non_blocking_work_queues_[(start_index + j) % sharding_factor] + ->queue.PopBack(); + if (t.f) { + return t; + } + if (!search_from_all_queue) { + break; + } + } + return t; +} + +void ThreadWorkSource::WaitForWork(int max_sleep_micros) { + thread_local Waiter waiter; + WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros); +} + +int ThreadWorkSource::TaskQueueSize(bool is_blocking) { + if (is_blocking) { + return blocking_work_queue_.Size(); + } else { + unsigned total_size = 0; + for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) { + total_size += non_blocking_work_queues_[i]->queue.Size(); + } + return total_size; + } +} + +int64 ThreadWorkSource::GetTracemeId() { + return traceme_id_.load(std::memory_order_relaxed); +} + +void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; } + +void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) { + { + tf_shared_lock lock(run_handler_waiter_mu_); + // Most of the request won't change sub pool for recomputation. + // Optimization for avoiding holding exclusive lock to reduce contention. + if (sub_thread_pool_waiter_ == waiter) { + return; + } + // If the current version is a newer version, no need to update. + if (version_ > version) { + return; } } - ~ThreadWorkSource() { - for (int i = 0; i < non_blocking_work_queues_.size(); ++i) { - delete non_blocking_work_queues_[i]; - } - } + mutex_lock l(run_handler_waiter_mu_); + sub_thread_pool_waiter_ = waiter; + sub_thread_pool_waiter_mu_ = mutex; + version_ = version; +} - Task EnqueueTask(Task t, bool is_blocking) { - mutex* mu = nullptr; - Queue* task_queue = nullptr; - thread_local int64 closure_counter = 0; +int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) { + std::atomic* counter = + is_blocking ? 
&blocking_inflight_ : &non_blocking_inflight_; + return counter->load(std::memory_order_relaxed); +} - if (!is_blocking) { - int queue_index = ++closure_counter % non_blocking_work_sharding_factor_; - task_queue = &(non_blocking_work_queues_[queue_index]->queue); - mu = &non_blocking_work_queues_[queue_index]->queue_op_mu; - } else { - task_queue = &blocking_work_queue_; - mu = &blocking_queue_op_mu_; - } +void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) { + std::atomic* counter = + is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; + counter->fetch_add(1, std::memory_order_relaxed); +} +void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) { + std::atomic* counter = + is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; + counter->fetch_sub(1, std::memory_order_relaxed); +} + +unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() { + return non_blocking_work_sharding_factor_; +} + +std::string ThreadWorkSource::ToString() { + return strings::StrCat("traceme_id = ", GetTracemeId(), + ", inter queue size = ", TaskQueueSize(true), + ", inter inflight = ", GetInflightTaskCount(true), + ", intra queue size = ", TaskQueueSize(false), + ", intra inflight = ", GetInflightTaskCount(false)); +} + +RunHandlerThreadPool::RunHandlerThreadPool( + int num_blocking_threads, int num_non_blocking_threads, Env* env, + const ThreadOptions& thread_options, const string& name, + Eigen::MaxSizeVector* waiters_mu, + Eigen::MaxSizeVector* queue_waiters) + : num_threads_(num_blocking_threads + num_non_blocking_threads), + num_blocking_threads_(num_blocking_threads), + num_non_blocking_threads_(num_non_blocking_threads), + thread_data_(num_threads_), + env_(env, thread_options, name), + name_(name), + waiters_mu_(waiters_mu), + queue_waiters_(queue_waiters), + use_sub_thread_pool_(ParamFromEnvBoolWithDefault( + "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)), + num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault( + "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", + std::vector({num_blocking_threads / 2, + num_blocking_threads - num_blocking_threads / 2}))), + sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault( + "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", + std::vector({0, 0.4}))), + sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault( + "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", + std::vector({0.4, 1}))) { + thread_data_.resize(num_threads_); + VLOG(1) << "Creating RunHandlerThreadPool " << name << " with " + << num_blocking_threads_ << " blocking threads and " + << num_non_blocking_threads_ << " non-blocking threads."; +} + +RunHandlerThreadPool::~RunHandlerThreadPool() { + VLOG(1) << "Exiting RunHandlerThreadPool " << name_; + + cancelled_ = true; + for (size_t i = 0; i < thread_data_.size(); ++i) { { - mutex_lock l(*mu); - // For a given queue, only one thread can call PushFront. - t = task_queue->PushFront(std::move(t)); + mutex_lock l(thread_data_[i].mu); + thread_data_[i].sources_not_empty.notify_all(); } - - Waiter* w = nullptr; - bool use_sub_thread_pool = ParamFromEnvBoolWithDefault( - "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false); - - Waiter* waiter_queue; - mutex* waiter_queue_mu; - if (use_sub_thread_pool) { - // When we use multiple sub thread pools, free threads wait on sub - // thread pool waiting queues. Wake up threads from sub thread waiting - // queues. - // The waiting queues are defined at RunHandlerPool. - // Get the waiter_queue and corresponding mutex. 
Note, the thread work - // source may change afterwards if a new request comes or an old request - // finishes. - tf_shared_lock lock(run_handler_waiter_mu_); - waiter_queue = sub_thread_pool_waiter_; - waiter_queue_mu = sub_thread_pool_waiter_mu_; - } else { - waiter_queue = &queue_waiters_; - waiter_queue_mu = &waiters_mu_; - } - - { - mutex_lock l(*waiter_queue_mu); - if (waiter_queue->next != waiter_queue) { - // Remove waiter from the LIFO queue - w = waiter_queue->next; - - CHECK(w->prev != w); - CHECK(w->next != w); - - w->next->prev = w->prev; - w->prev->next = w->next; - - // Use `w->next == &w` to indicate that the waiter has been removed - // from the queue. - w->next = w; - w->prev = w; - } - } - if (w != nullptr) { - // We call notify_one() without any locks, so we can miss notifications. - // The wake up logic is best effort and a thread will wake in short - // period of time in case a notification is missed. - w->cv.notify_one(); - } - VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from " - << traceme_id_.load(std::memory_order_relaxed); - return t; + thread_data_[i].thread.reset(); } +} - Task PopBlockingTask() { return blocking_work_queue_.PopBack(); } - - Task PopNonBlockingTask(int start_index, bool search_from_all_queue) { - Task t; - unsigned sharding_factor = NonBlockingWorkShardingFactor(); - for (unsigned j = 0; j < sharding_factor; ++j) { - t = non_blocking_work_queues_[(start_index + j) % sharding_factor] - ->queue.PopBack(); - if (t.f) { - return t; - } - if (!search_from_all_queue) { +void RunHandlerThreadPool::Start() { + cancelled_ = false; + int num_blocking_threads = num_blocking_threads_; + for (int i = 0; i < num_threads_; i++) { + int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1; + for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) { + if (i < num_threads_in_sub_thread_pool_[j]) { + sub_thread_pool_id = j; break; } } - return t; + thread_data_[i].sub_thread_pool_id = sub_thread_pool_id; + thread_data_[i].thread.reset( + env_.CreateThread([this, i, num_blocking_threads]() { + WorkerLoop(i, i < num_blocking_threads); + })); } +} - void WaitForWork(int max_sleep_micros) { - thread_local Waiter waiter; - WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros); +void RunHandlerThreadPool::StartOneThreadForTesting() { + cancelled_ = false; + thread_data_[0].sub_thread_pool_id = 0; + thread_data_[0].thread.reset( + env_.CreateThread([this]() { WorkerLoop(0, true); })); +} + +void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws, + bool is_blocking, + std::function fn) { + Task t = env_.CreateTask(std::move(fn)); + t = tws->EnqueueTask(std::move(t), is_blocking); + if (t.f) { + VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for " + << tws->GetTracemeId(); + env_.ExecuteTask(t); } +} - int TaskQueueSize(bool is_blocking) { - if (is_blocking) { - return blocking_work_queue_.Size(); - } else { - unsigned total_size = 0; - for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) { - total_size += non_blocking_work_queues_[i]->queue.Size(); - } - return total_size; - } +// TODO(donglin) Change the task steal order to be round-robin such that if +// an attempt to steal task from request i failed, then attempt to steal task +// from the next request in terms of the arrival time. This approach may +// provide better performance due to less lock retention. The drawback is that +// the profiler will be a bit harder to read. 
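// Illustrative sketch (standalone, not the production code): the AddWorkToQueue /
// EnqueueTask path above relies on the bounded queue handing a task back when it is
// full, in which case the enqueuing thread runs the closure inline instead of
// dropping it. The BoundedQueue below is a stand-in for Eigen::RunQueue, used only
// to show that caller-runs fallback.
#include <deque>
#include <functional>
#include <iostream>

namespace sketch {

class BoundedQueue {
 public:
  explicit BoundedQueue(size_t capacity) : capacity_(capacity) {}

  // Returns an empty std::function on success; returns the closure back on
  // overflow, mirroring RunQueue::PushFront returning the rejected task.
  std::function<void()> PushFront(std::function<void()> f) {
    if (queue_.size() >= capacity_) return f;  // Queue full: hand the task back.
    queue_.push_front(std::move(f));
    return nullptr;
  }

 private:
  std::deque<std::function<void()>> queue_;
  size_t capacity_;
};

void AddWorkToQueue(BoundedQueue* queue, std::function<void()> fn) {
  std::function<void()> rejected = queue->PushFront(std::move(fn));
  if (rejected) rejected();  // Caller-runs: execute inline rather than losing work.
}

}  // namespace sketch

int main() {
  sketch::BoundedQueue queue(/*capacity=*/1);
  sketch::AddWorkToQueue(&queue, [] { std::cout << "queued\n"; });      // Fits.
  sketch::AddWorkToQueue(&queue, [] { std::cout << "ran inline\n"; });  // Overflows.
}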
+void RunHandlerThreadPool::SetThreadWorkSources( + int tid, int start_request_idx, uint64 version, + const Eigen::MaxSizeVector& thread_work_sources) { + mutex_lock l(thread_data_[tid].mu); + if (version > thread_data_[tid].new_version) { + thread_data_[tid].new_version = version; + } else { + // A newer version is already updated. No need to update. + return; } - - int64 GetTracemeId() { return traceme_id_.load(std::memory_order_relaxed); } - - void SetTracemeId(int64 value) { traceme_id_ = value; } - - void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) { - { - tf_shared_lock lock(run_handler_waiter_mu_); - // Most of the request won't change sub pool for recomputation. - // Optimization for avoiding holding exclusive lock to reduce contention. - if (sub_thread_pool_waiter_ == waiter) { - return; - } - // If the current version is a newer version, no need to update. - if (version_ > version) { - return; - } - } - - mutex_lock l(run_handler_waiter_mu_); - sub_thread_pool_waiter_ = waiter; - sub_thread_pool_waiter_mu_ = mutex; - version_ = version; - } - - int64 GetInflightTaskCount(bool is_blocking) { - std::atomic* counter = - is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; - return counter->load(std::memory_order_relaxed); - } - - void IncrementInflightTaskCount(bool is_blocking) { - std::atomic* counter = - is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; - counter->fetch_add(1, std::memory_order_relaxed); - } - - void DecrementInflightTaskCount(bool is_blocking) { - std::atomic* counter = - is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; - counter->fetch_sub(1, std::memory_order_relaxed); - } - - unsigned NonBlockingWorkShardingFactor() { - return non_blocking_work_sharding_factor_; - } - - std::string ToString() { - return strings::StrCat("traceme_id = ", GetTracemeId(), - ", inter queue size = ", TaskQueueSize(true), - ", inter inflight = ", GetInflightTaskCount(true), - ", intra queue size = ", TaskQueueSize(false), - ", intra inflight = ", GetInflightTaskCount(false)); - } - - private: - struct NonBlockingQueue { - mutex queue_op_mu; - char pad[128]; - Queue queue; - }; - - int32 non_blocking_work_sharding_factor_; - Eigen::MaxSizeVector non_blocking_work_queues_; - - std::atomic blocking_inflight_; - std::atomic non_blocking_inflight_; - - Queue blocking_work_queue_; - mutex blocking_queue_op_mu_; - char pad_[128]; - mutex waiters_mu_; - Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_); - std::atomic traceme_id_; - - mutex run_handler_waiter_mu_; - uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_); - mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_); - Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_); -}; - -class RunHandlerThreadPool { - public: - struct PerThread { - constexpr PerThread() : pool(nullptr), thread_id(-1) {} - RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. - int thread_id; // Worker thread index in pool. 
- }; - - RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads, - Env* env, const ThreadOptions& thread_options, - const string& name, - Eigen::MaxSizeVector* waiters_mu, - Eigen::MaxSizeVector* queue_waiters) - : num_threads_(num_blocking_threads + num_non_blocking_threads), - num_blocking_threads_(num_blocking_threads), - num_non_blocking_threads_(num_non_blocking_threads), - thread_data_(num_threads_), - env_(env, thread_options, name), - name_(name), - waiters_mu_(waiters_mu), - queue_waiters_(queue_waiters), - use_sub_thread_pool_(ParamFromEnvBoolWithDefault( - "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)), - num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", - std::vector( - {num_blocking_threads / 2, - num_blocking_threads - num_blocking_threads / 2}))), - sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", - std::vector({0, 0.4}))), - sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", - std::vector({0.4, 1}))) { - VLOG(1) << "Creating RunHandlerThreadPool " << name << " with " - << num_blocking_threads_ << " blocking threads and " - << num_non_blocking_threads_ << " non-blocking threads."; - } - - ~RunHandlerThreadPool() { - VLOG(1) << "Exiting RunHandlerThreadPool " << name_; - - cancelled_ = true; - for (size_t i = 0; i < thread_data_.size(); ++i) { - { - mutex_lock l(thread_data_[i].mu); - thread_data_[i].sources_not_empty.notify_all(); - } - thread_data_[i].thread.reset(); - } - } - - void Start() { - cancelled_ = false; - thread_data_.resize(num_threads_); - int num_blocking_threads = num_blocking_threads_; - for (int i = 0; i < num_threads_; i++) { - int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1; - for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) { - if (i < num_threads_in_sub_thread_pool_[j]) { - sub_thread_pool_id = j; - break; - } - } - thread_data_[i].sub_thread_pool_id = sub_thread_pool_id; - thread_data_[i].thread.reset( - env_.CreateThread([this, i, num_blocking_threads]() { - WorkerLoop(i, i < num_blocking_threads); - })); - } - } - - void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, - std::function fn) { - Task t = env_.CreateTask(std::move(fn)); - t = tws->EnqueueTask(std::move(t), is_blocking); - if (t.f) { - VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for " - << tws->GetTracemeId(); - env_.ExecuteTask(t); - } - } - - // Set work queues from which the thread 'tid' can steal its work. - // The request with start_request_idx will be attempted first. Other requests - // will be attempted in FIFO order based on their arrival time. - - // TODO(donglin) Change the task steal order to be round-robin such that if - // an attempt to steal task from request i failed, then attempt to steal task - // from the next request in terms of the arrival time. This approach may - // provide better performance due to less lock retention. The drawback is that - // the profiler will be a bit harder to read. - void SetThreadWorkSources( - int tid, int start_request_idx, uint64 version, - const Eigen::MaxSizeVector& thread_work_sources) { - mutex_lock l(thread_data_[tid].mu); - if (version > thread_data_[tid].new_version) { - thread_data_[tid].new_version = version; - } else { - // A newer version is already updated. No need to update. 
- return; - } - thread_data_[tid].new_thread_work_sources->resize(0); - - if (use_sub_thread_pool_) { - for (int i = 0; i < thread_work_sources.size(); ++i) { - thread_data_[tid].new_thread_work_sources->emplace_back( - thread_work_sources[i]); - } - } else { + thread_data_[tid].new_thread_work_sources->resize(0); + if (use_sub_thread_pool_) { + for (int i = 0; i < thread_work_sources.size(); ++i) { thread_data_[tid].new_thread_work_sources->emplace_back( - thread_work_sources[start_request_idx]); - // The number of shards for the queue. Threads in each shard will - // prioritize different thread_work_sources. Increase the number of shards - // could decrease the contention in the queue. For example, when - // num_shards == 1: thread_work_sources are ordered as start_request_idx, - // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2: - // thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3, - // 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2, - // 4... for the other half of the threads. - int num_shards = - ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1); - int token = tid % num_shards; - for (int i = 0; i < num_shards; ++i) { - for (int j = token; j < thread_work_sources.size(); j += num_shards) { - if (j != start_request_idx) { - thread_data_[tid].new_thread_work_sources->emplace_back( - thread_work_sources[j]); - } + thread_work_sources[i]); + } + } else { + thread_data_[tid].new_thread_work_sources->emplace_back( + thread_work_sources[start_request_idx]); + // The number of shards for the queue. Threads in each shard will + // prioritize different thread_work_sources. Increase the number of shards + // could decrease the contention in the queue. For example, when + // num_shards == 1: thread_work_sources are ordered as start_request_idx, + // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2: + // thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3, + // 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2, + // 4... for the other half of the threads. 
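// Illustrative sketch (standalone helper, not part of the change): the interleaving
// described in the comment above, and implemented by the sharded loop below, can be
// computed in isolation. Given a thread id, the shard count and start_request_idx,
// this returns the order in which that thread visits the thread_work_sources.
#include <iostream>
#include <vector>

std::vector<int> StealOrder(int tid, int num_shards, int start_request_idx,
                            int num_sources) {
  std::vector<int> order;
  order.push_back(start_request_idx);  // The thread's own request is tried first.
  int token = tid % num_shards;
  for (int i = 0; i < num_shards; ++i) {
    for (int j = token; j < num_sources; j += num_shards) {
      if (j != start_request_idx) order.push_back(j);
    }
    token = (token + 1) % num_shards;
  }
  return order;
}

int main() {
  // With num_shards == 2 and 6 requests starting at request 0:
  //   tid 0 visits 0 2 4 1 3 5, tid 1 visits 0 1 3 5 2 4.
  for (int tid = 0; tid < 2; ++tid) {
    for (int idx : StealOrder(tid, /*num_shards=*/2, /*start_request_idx=*/0,
                              /*num_sources=*/6)) {
      std::cout << idx << ' ';
    }
    std::cout << '\n';
  }
}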
+ int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1); + int token = tid % num_shards; + for (int i = 0; i < num_shards; ++i) { + for (int j = token; j < thread_work_sources.size(); j += num_shards) { + if (j != start_request_idx) { + thread_data_[tid].new_thread_work_sources->emplace_back( + thread_work_sources[j]); } - token = (token + 1) % num_shards; } - thread_data_[tid].sources_not_empty.notify_all(); + token = (token + 1) % num_shards; } + thread_data_[tid].sources_not_empty.notify_all(); } +} - PerThread* GetPerThread() { - thread_local PerThread per_thread_; - PerThread* pt = &per_thread_; - return pt; +RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() { + thread_local RunHandlerThreadPool::PerThread per_thread_; + RunHandlerThreadPool::PerThread* pt = &per_thread_; + return pt; +} + +int RunHandlerThreadPool::CurrentThreadId() const { + const PerThread* pt = const_cast(this)->GetPerThread(); + if (pt->pool == this) { + return pt->thread_id; + } else { + return -1; } +} - int CurrentThreadId() const { - const PerThread* pt = - const_cast(this)->GetPerThread(); - if (pt->pool == this) { - return pt->thread_id; - } else { - return -1; - } - } +int RunHandlerThreadPool::NumThreads() const { return num_threads_; } - int NumThreads() const { return num_threads_; } +int RunHandlerThreadPool::NumBlockingThreads() const { + return num_blocking_threads_; +} - int NumBlockingThreads() const { return num_blocking_threads_; } +int RunHandlerThreadPool::NumNonBlockingThreads() const { + return num_non_blocking_threads_; +} - int NumNonBlockingThreads() const { return num_non_blocking_threads_; } - - void WorkerLoop(int thread_id, bool may_steal_blocking_work); - - // Search tasks from Requets range searching_range_start to - // searching_range_end. If there is no tasks in the search range and - // may_steal_blocking_work is true, then search from all requests. - Task FindTask( - int searching_range_start, int searching_range_end, int thread_id, - int sub_thread_pool_id, int max_blocking_inflight, - bool may_steal_blocking_work, - const Eigen::MaxSizeVector& thread_work_sources, - bool* task_from_blocking_queue, ThreadWorkSource** tws); - - void WaitForWork(bool is_blocking, int thread_id, - int32 max_blocking_inflight); - - void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id); - - private: - struct ThreadData { - ThreadData() - : new_version(0), - current_index(0), - new_thread_work_sources(new Eigen::MaxSizeVector( - static_cast(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", - kMaxConcurrentHandlers)))), - current_version(0), - current_thread_work_sources( - new Eigen::MaxSizeVector( - static_cast(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", - kMaxConcurrentHandlers)))) {} - mutex mu; - uint64 new_version; - condition_variable sources_not_empty; - std::unique_ptr thread; - int current_index; - std::unique_ptr> - new_thread_work_sources TF_GUARDED_BY(mu); - - uint64 current_version; - // Should only be accessed by one thread. 
- std::unique_ptr> - current_thread_work_sources; - - int sub_thread_pool_id; - }; - - const int num_threads_; - const int num_blocking_threads_; - const int num_non_blocking_threads_; - Eigen::MaxSizeVector thread_data_; - RunHandlerEnvironment env_; - std::atomic cancelled_; - string name_; - Eigen::MaxSizeVector* waiters_mu_; - Eigen::MaxSizeVector* queue_waiters_; - - bool use_sub_thread_pool_; - std::vector num_threads_in_sub_thread_pool_; - - // Threads in each sub thread pool will search tasks from the given - // start_request_percentage to end_request_percentage in a round robin - // fashion. - std::vector sub_thread_pool_start_request_percentage_; - std::vector sub_thread_pool_end_request_percentage_; -}; +RunHandlerThreadPool::ThreadData::ThreadData() + : new_version(0), + current_index(0), + new_thread_work_sources( + new Eigen::MaxSizeVector(static_cast( + ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", + kMaxConcurrentHandlers)))), + current_version(0), + current_thread_work_sources( + new Eigen::MaxSizeVector(static_cast( + ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", + kMaxConcurrentHandlers)))) {} Task RunHandlerThreadPool::FindTask( int searching_range_start, int searching_range_end, int thread_id, @@ -597,10 +487,9 @@ Task RunHandlerThreadPool::FindTask( int current_index = thread_data_[thread_id].current_index; *task_from_blocking_queue = false; - // TODO(chaox): Change the search algorithm from round robin to random - // walk. for (int i = 0; i < searching_range_end - searching_range_start; ++i) { - if (current_index >= searching_range_end) { + if (current_index >= searching_range_end || + current_index < searching_range_start) { current_index = searching_range_start; } *tws = thread_work_sources[current_index]; @@ -821,7 +710,7 @@ void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id, tws->WaitForWork(kMaxSleepMicros); } -} // namespace +} // namespace internal // Contains the concrete implementation of the RunHandler. // Externally visible RunHandler class simply forwards the work to this one. 
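// Illustrative sketch (standalone, not the production FindTask): the hunk above
// widens the reset condition on the per-thread cursor so that it snaps back to
// searching_range_start whenever the remembered index falls outside the current
// [range_start, range_end) window, not only when it overshoots the end. The cursor
// logic in isolation:
#include <iostream>

struct RoundRobinCursor {
  int current_index = 0;

  // Visits each index in [range_start, range_end) once, starting from the
  // remembered cursor, leaving the cursor one past the last index visited.
  template <typename Visit>
  void Scan(int range_start, int range_end, Visit visit) {
    for (int i = 0; i < range_end - range_start; ++i) {
      if (current_index >= range_end || current_index < range_start) {
        current_index = range_start;  // Range moved since the last call: reset.
      }
      visit(current_index);
      ++current_index;
    }
  }
};

int main() {
  RoundRobinCursor cursor;
  auto print = [](int i) { std::cout << i << ' '; };
  cursor.Scan(0, 2, print);  // 0 1        (cursor ends at 2)
  std::cout << '\n';
  cursor.Scan(3, 5, print);  // 3 4        (2 < 3, so the cursor resets to 3)
  std::cout << '\n';
  cursor.Scan(0, 5, print);  // 0 1 2 3 4  (5 >= 5, so the cursor resets to 0)
  std::cout << '\n';
}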
@@ -847,7 +736,7 @@ class RunHandler::Impl { RunHandlerPool::Impl* pool_impl() { return pool_impl_; } - ThreadWorkSource* tws() { return &tws_; } + internal::ThreadWorkSource* tws() { return &tws_; } int64 priority() { return options_.priority(); } @@ -869,7 +758,7 @@ class RunHandler::Impl { uint64 start_time_us_; int64 step_id_; std::unique_ptr thread_pool_interface_; - ThreadWorkSource tws_; + internal::ThreadWorkSource tws_; RunOptions::Experimental::RunHandlerPoolOptions options_; }; @@ -885,7 +774,7 @@ class RunHandlerPool::Impl { ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)), queue_waiters_( ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)), - run_handler_thread_pool_(new RunHandlerThreadPool( + run_handler_thread_pool_(new internal::RunHandlerThreadPool( num_inter_op_threads, num_intra_op_threads, Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu_, &queue_waiters_)), @@ -924,7 +813,7 @@ class RunHandlerPool::Impl { run_handler_thread_pool_.reset(); } - RunHandlerThreadPool* run_handler_thread_pool() { + internal::RunHandlerThreadPool* run_handler_thread_pool() { return run_handler_thread_pool_.get(); } @@ -936,10 +825,11 @@ class RunHandlerPool::Impl { int64 step_id, int64 timeout_in_ms, const RunOptions::Experimental::RunHandlerPoolOptions& options) TF_LOCKS_EXCLUDED(mu_) { - thread_local std::unique_ptr> + thread_local std::unique_ptr< + Eigen::MaxSizeVector> thread_work_sources = - std::unique_ptr>( - new Eigen::MaxSizeVector( + std::unique_ptr>( + new Eigen::MaxSizeVector( static_cast(ParamFromEnvWithDefault( "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers)))); @@ -1035,7 +925,8 @@ class RunHandlerPool::Impl { private: void RecomputePoolStats( int num_active_requests, uint64 version, - const Eigen::MaxSizeVector& thread_work_sources); + const Eigen::MaxSizeVector& + thread_work_sources); void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -1046,9 +937,9 @@ class RunHandlerPool::Impl { const int max_handlers_; Eigen::MaxSizeVector waiters_mu_; - Eigen::MaxSizeVector queue_waiters_; + Eigen::MaxSizeVector queue_waiters_; - std::unique_ptr run_handler_thread_pool_; + std::unique_ptr run_handler_thread_pool_; // Thread compatible part used only by lock under RunHandlerPool. // Handlers are sorted by start time. // TODO(azaks): sort by the remaining latency budget. @@ -1070,7 +961,8 @@ class RunHandlerPool::Impl { void RunHandlerPool::Impl::RecomputePoolStats( int num_active_requests, uint64 version, - const Eigen::MaxSizeVector& thread_work_sources) { + const Eigen::MaxSizeVector& + thread_work_sources) { if (num_active_requests == 0) return; int sub_thread_pool_id = 0; diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h index 13e31c1b93b..529c3deda3a 100644 --- a/tensorflow/core/framework/run_handler.h +++ b/tensorflow/core/framework/run_handler.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -106,6 +107,208 @@ class RunHandler { Impl* impl_; // NOT OWNED. 
}; +namespace internal { + +// TODO(azaks): Refactor with thread:ThreadPool +class RunHandlerEnvironment { + typedef Thread EnvThread; + struct TaskImpl { + std::function f; + Context context; + uint64 trace_id; + }; + Env* const env_; + const ThreadOptions thread_options_; + const string name_; + + public: + struct Task { + std::unique_ptr f; + }; + + RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options, + const string& name); + + EnvThread* CreateThread(std::function f); + + Task CreateTask(std::function f); + + void ExecuteTask(const Task& t); +}; + +typedef typename RunHandlerEnvironment::Task Task; +typedef Eigen::RunQueue Queue; + +// To reduce cache misses, we use a doubly-linked list of Waiter structs and +// queue them in LIFO order rather than the FIFO order used by a single +// condition variable. +struct Waiter { + Waiter() { + next = this; + prev = this; + } + condition_variable cv; + mutex mu; + Waiter* next; + Waiter* prev; +}; + +class ThreadWorkSource { + public: + ThreadWorkSource(); + + ~ThreadWorkSource(); + + Task EnqueueTask(Task t, bool is_blocking); + + Task PopBlockingTask(); + + Task PopNonBlockingTask(int start_index, bool search_from_all_queue); + + void WaitForWork(int max_sleep_micros); + + int TaskQueueSize(bool is_blocking); + + int64 GetTracemeId(); + + void SetTracemeId(int64 value); + + void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex); + + int64 GetInflightTaskCount(bool is_blocking); + + void IncrementInflightTaskCount(bool is_blocking); + + void DecrementInflightTaskCount(bool is_blocking); + + unsigned NonBlockingWorkShardingFactor(); + + std::string ToString(); + + private: + struct NonBlockingQueue { + mutex queue_op_mu; + char pad[128]; + Queue queue; + }; + + int32 non_blocking_work_sharding_factor_; + Eigen::MaxSizeVector non_blocking_work_queues_; + + std::atomic blocking_inflight_; + std::atomic non_blocking_inflight_; + + Queue blocking_work_queue_; + mutex blocking_queue_op_mu_; + char pad_[128]; + mutex waiters_mu_; + Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_); + std::atomic traceme_id_; + + mutex run_handler_waiter_mu_; + uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_); + mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_); + Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_); +}; + +class RunHandlerThreadPool { + public: + struct PerThread { + constexpr PerThread() : pool(nullptr), thread_id(-1) {} + RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. + int thread_id; // Worker thread index in pool. + }; + + RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads, + Env* env, const ThreadOptions& thread_options, + const string& name, + Eigen::MaxSizeVector* waiters_mu, + Eigen::MaxSizeVector* queue_waiters); + + ~RunHandlerThreadPool(); + + void Start(); + + void StartOneThreadForTesting(); + + void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, + std::function fn); + + // Set work queues from which the thread 'tid' can steal its work. + // The request with start_request_idx will be attempted first. Other requests + // will be attempted in FIFO order based on their arrival time. 
+ void SetThreadWorkSources( + int tid, int start_request_idx, uint64 version, + const Eigen::MaxSizeVector& thread_work_sources); + + PerThread* GetPerThread(); + + int CurrentThreadId() const; + + int NumThreads() const; + + int NumBlockingThreads() const; + + int NumNonBlockingThreads() const; + + void WorkerLoop(int thread_id, bool may_steal_blocking_work); + + // Search tasks from Requets range searching_range_start to + // searching_range_end. If there is no tasks in the search range and + // may_steal_blocking_work is true, then search from all requests. + Task FindTask( + int searching_range_start, int searching_range_end, int thread_id, + int sub_thread_pool_id, int max_blocking_inflight, + bool may_steal_blocking_work, + const Eigen::MaxSizeVector& thread_work_sources, + bool* task_from_blocking_queue, ThreadWorkSource** tws); + + void WaitForWork(bool is_blocking, int thread_id, + int32 max_blocking_inflight); + + void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id); + + private: + struct ThreadData { + ThreadData(); + mutex mu; + uint64 new_version; + condition_variable sources_not_empty; + std::unique_ptr thread; + int current_index; + std::unique_ptr> + new_thread_work_sources TF_GUARDED_BY(mu); + + uint64 current_version; + // Should only be accessed by one thread. + std::unique_ptr> + current_thread_work_sources; + + int sub_thread_pool_id; + }; + + const int num_threads_; + const int num_blocking_threads_; + const int num_non_blocking_threads_; + Eigen::MaxSizeVector thread_data_; + internal::RunHandlerEnvironment env_; + std::atomic cancelled_; + string name_; + Eigen::MaxSizeVector* waiters_mu_; + Eigen::MaxSizeVector* queue_waiters_; + + bool use_sub_thread_pool_; + std::vector num_threads_in_sub_thread_pool_; + + // Threads in each sub thread pool will search tasks from the given + // start_request_percentage to end_request_percentage in a round robin + // fashion. + std::vector sub_thread_pool_start_request_percentage_; + std::vector sub_thread_pool_end_request_percentage_; +}; + +} // namespace internal + } // end namespace tensorflow. 
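// Illustrative sketch (standalone, using std:: primitives rather than the
// tensorflow mutex/condition_variable wrappers): the Waiter list declared above
// parks each idle thread on its own thread-local Waiter at the head of a circular
// doubly-linked list, and a producer wakes the most recently parked thread (LIFO)
// by unlinking it and signalling its private condition variable. Wake-ups are best
// effort, so parked threads always use a bounded wait.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

struct Waiter {
  Waiter() { next = prev = this; }
  std::condition_variable cv;
  std::mutex mu;
  Waiter* next;
  Waiter* prev;
};

std::mutex list_mu;
Waiter sentinel;  // Empty list: sentinel.next == sentinel.prev == &sentinel.

void Park(Waiter* w, std::chrono::milliseconds max_sleep) {
  {
    std::lock_guard<std::mutex> l(list_mu);
    // Link at the head so the most recently parked waiter is woken first (LIFO).
    w->prev = &sentinel;
    w->next = sentinel.next;
    sentinel.next->prev = w;
    sentinel.next = w;
  }
  {
    // Bounded sleep: a notification sent before we start waiting is simply missed,
    // so never block longer than max_sleep.
    std::unique_lock<std::mutex> l(w->mu);
    w->cv.wait_for(l, max_sleep);
  }
  std::lock_guard<std::mutex> l(list_mu);
  if (w->next != w) {  // Timed out while still queued: unlink ourselves.
    w->next->prev = w->prev;
    w->prev->next = w->next;
    w->next = w->prev = w;
  }
}

void WakeOne() {
  Waiter* w = nullptr;
  {
    std::lock_guard<std::mutex> l(list_mu);
    if (sentinel.next != &sentinel) {
      w = sentinel.next;  // Most recently parked waiter.
      w->next->prev = w->prev;
      w->prev->next = w->next;
      w->next = w->prev = w;  // `w->next == w` marks "no longer queued".
    }
  }
  if (w != nullptr) w->cv.notify_one();  // Notified without holding any lock.
}

int main() {
  std::thread t([] {
    thread_local Waiter waiter;
    Park(&waiter, std::chrono::milliseconds(250));
    std::cout << "worker woke up (or timed out)\n";
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  WakeOne();
  t.join();
}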
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ diff --git a/tensorflow/core/framework/run_handler_test.cc b/tensorflow/core/framework/run_handler_test.cc index 14b0a559641..d849bf0f17f 100644 --- a/tensorflow/core/framework/run_handler_test.cc +++ b/tensorflow/core/framework/run_handler_test.cc @@ -113,6 +113,476 @@ TEST(RunHandlerUtilTest, PrioritySchedulingTest) { EXPECT_EQ(sorted_active_list[3], 1); } +TEST(RunHandlerThreadPool, EnqueueTask) { + Eigen::MaxSizeVector waiters_mu(2); + waiters_mu.resize(2); + Eigen::MaxSizeVector waiters(2); + waiters.resize(2); + internal::RunHandlerThreadPool run_handler_thread_pool( + /*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0, + Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu, + &waiters); + internal::ThreadWorkSource tws; + + int result = 0; + std::function fn = [&result] { result = 1; }; + std::function fn2 = [&result] { result = 2; }; + run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn); + EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 1); + run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn2); + EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 2); + tws.PopBlockingTask().f->f(); + EXPECT_EQ(result, 1); + tws.PopBlockingTask().f->f(); + EXPECT_EQ(result, 2); + + run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn); + EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 1); + run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn2); + EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 2); + tws.PopNonBlockingTask(0, true).f->f(); + EXPECT_EQ(result, 1); + tws.PopNonBlockingTask(0, true).f->f(); + EXPECT_EQ(result, 2); +} + +TEST(RunHandlerThreadPool, FindTask) { + Eigen::MaxSizeVector waiters_mu(2); + waiters_mu.resize(2); + Eigen::MaxSizeVector waiters(2); + waiters.resize(2); + internal::RunHandlerThreadPool run_handler_thread_pool( + /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0, + Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu, + &waiters); + + Eigen::MaxSizeVector thread_work_sources(5); + thread_work_sources.resize(5); + for (int i = 0; i < 5; ++i) { + thread_work_sources[i] = new internal::ThreadWorkSource(); + } + + { + // The thread should search the task following round robin fashion. 
+ int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3], + /*is_blocking=*/true, + [&result] { result = 3; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3], + /*is_blocking=*/true, + [&result] { result = 3; }); + + const auto find_blocking_task_from_all_handlers = + [&](bool* task_from_blocking_queue, internal::Task* t) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + /*searching_range_start=*/0, /*searching_range_end=*/5, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + /*may_steal_blocking_work=*/true, thread_work_sources, + task_from_blocking_queue, &tws); + }; + bool task_from_blocking_queue; + internal::Task t; + find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + + find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 3); + + find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + + find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 3); + } + + { + // Task out of searching range cannot be found. + int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3], + /*is_blocking=*/true, + [&result] { result = 3; }); + + const auto find_blocking_task_from_range = + [&](bool* task_from_blocking_queue, internal::Task* t, int range_start, + int range_end) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + range_start, range_end, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + /*may_steal_blocking_work=*/true, thread_work_sources, + task_from_blocking_queue, &tws); + }; + + bool task_from_blocking_queue; + internal::Task t; + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3); + EXPECT_EQ(t.f, nullptr); + + // Clean up the queue. + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5); + } + + { + // The thread should search from start range if the currrent index is + // smaller. 
+ int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3], + /*is_blocking=*/true, + [&result] { result = 3; }); + + const auto find_blocking_task_from_range = + [&](bool* task_from_blocking_queue, internal::Task* t, int range_start, + int range_end) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + range_start, range_end, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + /*may_steal_blocking_work=*/true, thread_work_sources, + task_from_blocking_queue, &tws); + }; + bool task_from_blocking_queue; + internal::Task t; + find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 3); + + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + } + + { + // The thread should search within the range even if the current index + // is larger than searching_range_end; + int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + + const auto find_blocking_task_from_range = + [&](bool* task_from_blocking_queue, internal::Task* t, int range_start, + int range_end) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + range_start, range_end, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + /*may_steal_blocking_work=*/true, thread_work_sources, + task_from_blocking_queue, &tws); + }; + bool task_from_blocking_queue; + // Make the current index to be 3. + internal::Task t; + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + + // Search in a smaller range. + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3], + /*is_blocking=*/true, + [&result] { result = 3; }); + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + + // Clean up the queue. + find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 3); + } + + { + // We prefer blocking task for blocking threads. 
+ int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/false, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + const auto blocking_thread_find_task_from_all_handler = + [&](bool* task_from_blocking_queue, internal::Task* t) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + /*searching_range_start=*/0, /*searching_range_end=*/5, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + /*may_steal_blocking_work=*/true, thread_work_sources, + task_from_blocking_queue, &tws); + }; + bool task_from_blocking_queue; + internal::Task t; + blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, true); + t.f->f(); + EXPECT_EQ(result, 2); + + blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t); + EXPECT_EQ(task_from_blocking_queue, false); + t.f->f(); + EXPECT_EQ(result, 2); + } + + { + // Nonblocking threads can only pick up non-blocking task. + int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/false, + [&result] { result = 2; }); + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + + const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue, + internal::Task* t, + bool is_blocking_thread) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + /*searching_range_start=*/0, /*searching_range_end=*/5, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + is_blocking_thread, thread_work_sources, task_from_blocking_queue, + &tws); + }; + bool task_from_blocking_queue; + internal::Task t; + find_task_from_all_handler(&task_from_blocking_queue, &t, + /*is_blocking_thread=*/false); + EXPECT_EQ(task_from_blocking_queue, false); + t.f->f(); + EXPECT_EQ(result, 2); + + find_task_from_all_handler(&task_from_blocking_queue, &t, + /*is_blocking_thread=*/false); + EXPECT_EQ(t.f, nullptr); + + // Clean up the queue. + find_task_from_all_handler(&task_from_blocking_queue, &t, + /*is_blocking_thread=*/true); + } + + { + // There is a limit for max_blocking_inflight requests. + int result = -1; + run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2], + /*is_blocking=*/true, + [&result] { result = 2; }); + + const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue, + internal::Task* t, + bool is_blocking_thread) { + internal::ThreadWorkSource* tws; + *t = run_handler_thread_pool.FindTask( + /*searching_range_start=*/0, /*searching_range_end=*/5, + /*thread_id=*/0, + /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10, + is_blocking_thread, thread_work_sources, task_from_blocking_queue, + &tws); + }; + + bool task_from_blocking_queue; + internal::Task t; + find_task_from_all_handler(&task_from_blocking_queue, &t, + /*is_blocking_thread=*/false); + EXPECT_EQ(task_from_blocking_queue, false); + EXPECT_EQ(t.f, nullptr); + + // Clean up the queue. + find_task_from_all_handler(&task_from_blocking_queue, &t, + /*is_blocking_thread=*/true); + } + + for (int i = 0; i < 5; ++i) { + delete thread_work_sources[i]; + } +} + +TEST(RunHandlerThreadPool, RoundRobinExecution) { + // Set up environment for 1 sub thread pool. 
+ setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true); + setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true); + setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true); + setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true); + + Eigen::MaxSizeVector waiters_mu(1); + waiters_mu.resize(1); + Eigen::MaxSizeVector waiters(1); + waiters.resize(1); + internal::RunHandlerThreadPool* run_handler_thread_pool = + new internal::RunHandlerThreadPool( + /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0, + Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu, + &waiters); + Eigen::MaxSizeVector thread_work_sources(3); + thread_work_sources.resize(3); + internal::ThreadWorkSource tws[3]; + for (int i = 0; i < 3; ++i) { + tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]); + thread_work_sources[i] = &tws[i]; + } + + int result = 0; + mutex mu; + bool ok_to_execute = false; + bool ok_to_validate = false; + condition_variable function_start; + condition_variable function_end; + std::vector> fns; + for (int i = 0; i < 3; ++i) { + fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute, + &ok_to_validate, i] { + mutex_lock l(mu); + while (!ok_to_execute) { + function_start.wait(l); + } + result = i; + ok_to_execute = false; + ok_to_validate = true; + function_end.notify_one(); + }); + run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true, + fns[i]); + run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true, + fns[i]); + } + run_handler_thread_pool->Start(); + run_handler_thread_pool->SetThreadWorkSources( + /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources); + + // Validate the execution should be roundrobin. + mutex_lock l(mu); + for (int round = 0; round < 2; ++round) { + for (int i = 0; i < 3; ++i) { + ok_to_execute = true; + function_start.notify_one(); + while (!ok_to_validate) { + function_end.wait(l); + } + ok_to_validate = false; + EXPECT_EQ(result, i); + } + } + + delete run_handler_thread_pool; +} + +TEST(RunHandlerThreadPool, MultipleSubThreadPool) { + // Set up environment for 2 sub thread pools. 
+ setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true); + setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true); + setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5", + true); + setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1", + true); + + Eigen::MaxSizeVector waiters_mu(2); + waiters_mu.resize(2); + Eigen::MaxSizeVector waiters(2); + waiters.resize(2); + internal::RunHandlerThreadPool* run_handler_thread_pool = + new internal::RunHandlerThreadPool( + /*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0, + Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu, + &waiters); + Eigen::MaxSizeVector thread_work_sources(4); + thread_work_sources.resize(4); + internal::ThreadWorkSource tws[4]; + for (int i = 0; i < 4; ++i) { + tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]); + thread_work_sources[i] = &tws[i]; + } + + int result = 0; + mutex mu; + bool ok_to_execute = false; + bool ok_to_validate = false; + condition_variable function_start; + condition_variable function_end; + + std::vector> fns; + for (int i = 0; i < 4; ++i) { + fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute, + &ok_to_validate, i] { + mutex_lock l(mu); + while (!ok_to_execute) { + function_start.wait(l); + } + result = i; + ok_to_execute = false; + ok_to_validate = true; + function_end.notify_one(); + }); + run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true, + fns[i]); + run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true, + fns[i]); + } + run_handler_thread_pool->StartOneThreadForTesting(); + run_handler_thread_pool->SetThreadWorkSources( + /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources); + run_handler_thread_pool->SetThreadWorkSources( + /*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources); + + // Pick task from the given sub thread pool requests in a round robin fashion. + mutex_lock l(mu); + for (int round = 0; round < 2; ++round) { + for (int i = 0; i < 2; ++i) { + ok_to_execute = true; + function_start.notify_one(); + while (!ok_to_validate) { + function_end.wait(l); + } + ok_to_validate = false; + EXPECT_EQ(result, i); + } + } + + // Pick task from any task if there is no tasks from the requests in the sub + // thread pool. + for (int i = 0; i < 2; ++i) { + for (int round = 0; round < 2; ++round) { + ok_to_execute = true; + function_start.notify_one(); + while (!ok_to_validate) { + function_end.wait(l); + } + ok_to_validate = false; + EXPECT_EQ(result, i + 2); + } + } + + delete run_handler_thread_pool; +} + SessionOptions DefaultSessionOptions() { SessionOptions options; (*options.config.mutable_device_count())["CPU"] = 2;