From bfce5bae88e01c080c8cc9224b439c4dd26c506b Mon Sep 17 00:00:00 2001
From: Chao Xie
Date: Thu, 7 Nov 2019 21:34:49 -0800
Subject: [PATCH] Introduce sub thread pools in the run handler thread pool.

PiperOrigin-RevId: 279237293
Change-Id: I8f1945105d27d1ed06eee8bb914bfaa06fec6c1f
---
 tensorflow/core/BUILD                         |   7 +
 tensorflow/core/framework/run_handler.cc      | 439 ++++++++++++++----
 tensorflow/core/framework/run_handler_test.cc | 133 ++++++
 tensorflow/core/framework/run_handler_util.cc |  49 ++
 tensorflow/core/framework/run_handler_util.h  |  21 +-
 .../core/framework/run_handler_util_test.cc   |  39 ++
 6 files changed, 589 insertions(+), 99 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2d61ec49a0b..8369046aa85 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -4468,11 +4468,18 @@ tf_cc_test(
     srcs = ["framework/run_handler_test.cc"],
     linkstatic = tf_kernel_tests_linkstatic(),
     deps = [
+        ":core_cpu",
+        ":direct_session_internal",
         ":framework_internal",
         ":lib",
         ":lib_internal",
+        ":protos_all_cc",
+        ":tensor_testutil",
         ":test",
         ":test_main",
+        ":testlib",
+        "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:matmul_op",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/synchronization",
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
index 2fcbf3807ab..448082b99e1 100644
--- a/tensorflow/core/framework/run_handler.cc
+++ b/tensorflow/core/framework/run_handler.cc
@@ -98,6 +98,55 @@ class RunHandlerEnvironment {
 typedef typename RunHandlerEnvironment::Task Task;
 typedef Eigen::RunQueue<Task, 1024> Queue;
 
+// To reduce cache misses, we use a doubly-linked list of Waiter structs and
+// queue them in LIFO order rather than the FIFO order used by a single
+// condition variable.
+struct Waiter {
+  Waiter() {
+    next = this;
+    prev = this;
+  }
+  condition_variable cv;
+  mutex mu;
+  Waiter* next;
+  Waiter* prev;
+};
+
+void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
+                  int max_sleep_micros) {
+  {
+    mutex_lock l(*mutex);
+    CHECK_EQ(waiter->next, waiter);  // Crash OK.
+    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
+
+    // Add waiter to the LIFO queue.
+    waiter->prev = queue_head;
+    waiter->next = queue_head->next;
+    waiter->next->prev = waiter;
+    waiter->prev->next = waiter;
+  }
+  {
+    mutex_lock l(waiter->mu);
+    // Wait on the condition variable.
+    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
+  }
+
+  mutex_lock l(*mutex);
+  // Remove waiter from the LIFO queue. Note that even when a waiter wakes up
+  // due to a notification, we cannot conclude that it is not in the queue:
+  // a thread preempted right before notifying may resume after the waiter
+  // got re-added.
+  if (waiter->next != waiter) {
+    CHECK(waiter->prev != waiter);  // Crash OK.
+    waiter->next->prev = waiter->prev;
+    waiter->prev->next = waiter->next;
+    waiter->next = waiter;
+    waiter->prev = waiter;
+  } else {
+    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
+  }
+}
+
 class ThreadWorkSource {
  public:
   ThreadWorkSource()
@@ -155,11 +204,32 @@ class ThreadWorkSource {
     if (max_rank_to_wakeup > 0 &&
         rank_.load(std::memory_order_relaxed) <= max_rank_to_wakeup) {
       Waiter* w = nullptr;
+      bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
+          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
+
+      Waiter* waiter_queue;
+      mutex* waiter_queue_mu;
+      if (use_sub_thread_pool) {
+        // When we use multiple sub thread pools, free threads wait on the
+        // sub thread pool waiting queues; wake up threads from those queues.
+        // The waiting queues are defined at RunHandlerPool.
+        // Get the waiter_queue and corresponding mutex. Note that the thread
+        // work source may change afterwards if a new request comes in or an
+        // old request finishes.
+        tf_shared_lock lock(run_handler_waiter_mu_);
+        waiter_queue = sub_thread_pool_waiter_;
+        waiter_queue_mu = sub_thread_pool_waiter_mu_;
+      } else {
+        waiter_queue = &queue_waiters_;
+        waiter_queue_mu = &waiters_mu_;
+      }
+
       {
-        mutex_lock l(waiters_mu_);
-        if (queue_waiters_.next != &queue_waiters_) {
+        mutex_lock l(*waiter_queue_mu);
+        if (waiter_queue->next != waiter_queue) {
           // Remove waiter from the LIFO queue
-          w = queue_waiters_.next;
+          w = waiter_queue->next;
           CHECK(w->prev != w);
           CHECK(w->next != w);
@@ -187,43 +257,25 @@ class ThreadWorkSource {
 
   Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
 
-  Task PopNonBlockingTask(int index) {
-    return non_blocking_work_queues_[index]->queue.PopBack();
+  Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
+    Task t;
+    unsigned sharding_factor = NonBlockingWorkShardingFactor();
+    for (unsigned j = 0; j < sharding_factor; ++j) {
+      t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
+              ->queue.PopBack();
+      if (t.f) {
+        return t;
+      }
+      if (!search_from_all_queue) {
+        break;
+      }
+    }
+    return t;
   }
 
   void WaitForWork(int max_sleep_micros) {
     thread_local Waiter waiter;
-    {
-      mutex_lock l(waiters_mu_);
-      CHECK_EQ(waiter.next, &waiter);
-      CHECK_EQ(waiter.prev, &waiter);
-
-      // Add waiter to the LIFO queue
-      waiter.prev = &queue_waiters_;
-      waiter.next = queue_waiters_.next;
-      waiter.next->prev = &waiter;
-      waiter.prev->next = &waiter;
-    }
-    {
-      mutex_lock l(waiter.mu);
-      // Wait on the condition variable
-      waiter.cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
-    }
-
-    mutex_lock l(waiters_mu_);
-    // Remove waiter from the LIFO queue. Note even when a waiter wakes up due
-    // to a notification we cannot conclude the waiter is not in the queue.
-    // This is due to the fact that a thread preempted right before notifying
-    // may resume after a waiter got re-added.
-    if (waiter.next != &waiter) {
-      CHECK(waiter.prev != &waiter);
-      waiter.next->prev = waiter.prev;
-      waiter.prev->next = waiter.next;
-      waiter.next = &waiter;
-      waiter.prev = &waiter;
-    } else {
-      CHECK_EQ(waiter.prev, &waiter);
-    }
+    WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
   }
 
   int TaskQueueSize(bool is_blocking) {
@@ -243,6 +295,12 @@ class ThreadWorkSource {
   void SetTracemeId(int64 value) { traceme_id_ = value; }
 
   void SetRank(int64 value) { rank_ = value; }
 
+  void SetWaiter(Waiter* waiter, mutex* mutex) {
+    mutex_lock l(run_handler_waiter_mu_);
+    sub_thread_pool_waiter_ = waiter;
+    sub_thread_pool_waiter_mu_ = mutex;
+  }
+
   int64 GetInflightTaskCount(bool is_blocking) {
     std::atomic<int64>* counter =
         is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
@@ -274,20 +332,6 @@ class ThreadWorkSource {
   }
 
  private:
-  // To reduce cache misses, we use a doubly-linked list of Waiter structs and
-  // queue them in LIFO order rather than the FIFO order used by a single
-  // condition variable.
-  struct Waiter {
-    Waiter() {
-      next = this;
-      prev = this;
-    }
-    condition_variable cv;
-    mutex mu;
-    Waiter* next;
-    Waiter* prev;
-  };
-
   struct NonBlockingQueue {
     mutex queue_op_mu;
     char pad[128];
@@ -307,6 +351,10 @@ class ThreadWorkSource {
   Waiter queue_waiters_ GUARDED_BY(waiters_mu_);
   std::atomic<int64> traceme_id_;
   std::atomic<int64> rank_;
+
+  mutex run_handler_waiter_mu_;
+  mutex* sub_thread_pool_waiter_mu_ GUARDED_BY(run_handler_waiter_mu_);
+  Waiter* sub_thread_pool_waiter_ GUARDED_BY(run_handler_waiter_mu_);
 };
 
 class RunHandlerThreadPool {
@@ -319,25 +367,33 @@ class RunHandlerThreadPool {
 
   RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                        Env* env, const ThreadOptions& thread_options,
-                       const string& name)
+                       const string& name,
+                       Eigen::MaxSizeVector<mutex>* waiters_mu,
+                       Eigen::MaxSizeVector<Waiter>* queue_waiters)
       : num_threads_(num_blocking_threads + num_non_blocking_threads),
         num_blocking_threads_(num_blocking_threads),
        num_non_blocking_threads_(num_non_blocking_threads),
         thread_data_(num_threads_),
         env_(env, thread_options, name),
-        name_(name) {
+        name_(name),
+        waiters_mu_(waiters_mu),
+        queue_waiters_(queue_waiters),
+        use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
+            "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
+        num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
+            std::vector<int>(
+                {num_blocking_threads / 2,
+                 num_blocking_threads - num_blocking_threads / 2}))),
+        sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
+            std::vector<double>({0, 0.4}))),
+        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+            std::vector<double>({0.4, 1}))) {
     VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
             << num_blocking_threads_ << " blocking threads and "
             << num_non_blocking_threads_ << " non-blocking threads.";
-    cancelled_ = false;
-
-    thread_data_.resize(num_threads_);
-    for (int i = 0; i < num_threads_; i++) {
-      thread_data_[i].thread.reset(
-          env_.CreateThread([this, i, num_blocking_threads]() {
-            WorkerLoop(i, i < num_blocking_threads);
-          }));
-    }
   }
 
   ~RunHandlerThreadPool() {
@@ -353,6 +409,26 @@ class RunHandlerThreadPool {
     }
   }
 
+  void Start() {
+    cancelled_ = false;
+    thread_data_.resize(num_threads_);
+    int num_blocking_threads = num_blocking_threads_;
+    for (int i = 0; i < num_threads_; i++) {
+      int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
+      for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
+        if (i < num_threads_in_sub_thread_pool_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
+      thread_data_[i].thread.reset(
+          env_.CreateThread([this, i, num_blocking_threads]() {
+            WorkerLoop(i, i < num_blocking_threads);
+          }));
+    }
+  }
+
   void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                       std::function<void()> fn) {
     Task t = env_.CreateTask(std::move(fn));
@@ -384,30 +460,37 @@ class RunHandlerThreadPool {
       return;
     }
     thread_data_[tid].thread_work_sources.resize(0);
-    thread_data_[tid].thread_work_sources.emplace_back(
-        thread_work_sources[start_request_idx]);
-    // The number of shards for the queue. Threads in each shard will prioritize
-    // different thread_work_sources. Increase the number of shards could
-    // decrease the contention in the queue.
-    // For example, when num_shards == 1:
-    // thread_work_sources are ordered as start_request_idx, 0, 1, 2, 3, 4 ...
-    // for all threads.
-    // When num_shards == 2:
-    // thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
-    // 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
-    // 4... for the other half of the threads.
-    int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
-    int token = tid % num_shards;
-    for (int i = 0; i < num_shards; ++i) {
-      for (int j = token; j < thread_work_sources.size(); j += num_shards) {
-        if (j != start_request_idx) {
-          thread_data_[tid].thread_work_sources.emplace_back(
-              thread_work_sources[j]);
-        }
+
+    if (use_sub_thread_pool_) {
+      for (int i = 0; i < thread_work_sources.size(); ++i) {
+        thread_data_[tid].thread_work_sources.emplace_back(
+            thread_work_sources[i]);
       }
-      token = (token + 1) % num_shards;
+    } else {
+      thread_data_[tid].thread_work_sources.emplace_back(
+          thread_work_sources[start_request_idx]);
+      // The number of shards for the queue. Threads in each shard will
+      // prioritize different thread_work_sources. Increasing the number of
+      // shards can decrease contention in the queue. For example, when
+      // num_shards == 1: thread_work_sources are ordered as start_request_idx,
+      // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
+      // thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1,
+      // 3, 5 ... for half of the threads and start_request_idx, 1, 3, 5 ...
+      // 0, 2, 4 ... for the other half of the threads.
+      int num_shards =
+          ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
+      int token = tid % num_shards;
+      for (int i = 0; i < num_shards; ++i) {
+        for (int j = token; j < thread_work_sources.size(); j += num_shards) {
+          if (j != start_request_idx) {
+            thread_data_[tid].thread_work_sources.emplace_back(
+                thread_work_sources[j]);
+          }
+        }
+        token = (token + 1) % num_shards;
+      }
+      thread_data_[tid].sources_not_empty.notify_all();
     }
-    thread_data_[tid].sources_not_empty.notify_all();
   }
 
   PerThread* GetPerThread() {
@@ -434,13 +517,26 @@ class RunHandlerThreadPool {
 
   void WorkerLoop(int thread_id, bool may_steal_blocking_work);
 
+  // Search for tasks from requests in the range [searching_range_start,
+  // searching_range_end). If there are no tasks in the search range and
+  // may_steal_blocking_work is true, then search from all requests.
+  Task FindTask(
+      int searching_range_start, int searching_range_end, int thread_id,
+      int sub_thread_pool_id, int max_blocking_inflight,
+      bool may_steal_blocking_work,
+      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
+      bool* task_from_blocking_queue, ThreadWorkSource** tws);
+
   void WaitForWork(bool is_blocking, int thread_id,
                    int32 max_blocking_inflight);
 
+  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
+
  private:
   struct ThreadData {
     ThreadData()
         : version(0),
+          current_index(0),
           thread_work_sources(static_cast<int32>(
               ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                       kMaxConcurrentHandlers))) {}
     mutex mu;
     uint64 version;
     condition_variable sources_not_empty;
     std::unique_ptr<Thread> thread;
+    int current_index;
     Eigen::MaxSizeVector<ThreadWorkSource*> thread_work_sources GUARDED_BY(mu);
+    int sub_thread_pool_id;
   };
 
   const int num_threads_;
@@ -458,8 +556,58 @@ class RunHandlerThreadPool {
   RunHandlerEnvironment env_;
   std::atomic<bool> cancelled_;
   string name_;
+  Eigen::MaxSizeVector<mutex>* waiters_mu_;
+  Eigen::MaxSizeVector<Waiter>* queue_waiters_;
+
+  bool use_sub_thread_pool_;
+  std::vector<int> num_threads_in_sub_thread_pool_;
+
+  // Threads in each sub thread pool will search for tasks from the given
+  // start_request_percentage to end_request_percentage in a round robin
+  // fashion.
+  std::vector<double> sub_thread_pool_start_request_percentage_;
+  std::vector<double> sub_thread_pool_end_request_percentage_;
 };
 
+Task RunHandlerThreadPool::FindTask(
+    int searching_range_start, int searching_range_end, int thread_id,
+    int sub_thread_pool_id, int max_blocking_inflight,
+    bool may_steal_blocking_work,
+    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
+    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
+  Task t;
+  int current_index = thread_data_[thread_id].current_index;
+  *task_from_blocking_queue = false;
+
+  // TODO(chaox): Change the search algorithm from round robin to random
+  // walk.
+  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
+    if (current_index >= searching_range_end) {
+      current_index = searching_range_start;
+    }
+    *tws = thread_work_sources[current_index];
+    ++current_index;
+
+    // For blocking threads, search for blocking tasks first.
+    if (may_steal_blocking_work &&
+        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
+      t = (*tws)->PopBlockingTask();
+      if (t.f) {
+        *task_from_blocking_queue = true;
+        break;
+      }
+    }
+
+    // Search for non-blocking tasks.
+    t = (*tws)->PopNonBlockingTask(thread_id, true);
+    if (t.f) {
+      break;
+    }
+  }
+  thread_data_[thread_id].current_index = current_index;
+  return t;
+}
+
 // Main worker thread loop.
 void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                       bool may_steal_blocking_work) {
@@ -474,11 +622,47 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
     bool task_from_blocking_queue = true;
     Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
         &thread_data_[thread_id].thread_work_sources;
-    {
+    int sub_thread_pool_id;
+    if (use_sub_thread_pool_) {
+      // The mutex is not hot since it's per thread and can only be held
+      // by some other thread when a session run starts/finishes.
+      mutex_lock l(thread_data_[thread_id].mu);
+      sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
+      int active_requests = thread_work_sources->size();
+      if (may_steal_blocking_work) {
+        // Each thread will first look for tasks from requests that belong
+        // to its sub thread pool.
+        t = FindTask(
+            active_requests *
+                sub_thread_pool_start_request_percentage_[sub_thread_pool_id],
+            active_requests *
+                sub_thread_pool_end_request_percentage_[sub_thread_pool_id],
+            thread_id, sub_thread_pool_id, kMaxBlockingInflight,
+            /*may_steal_blocking_work=*/true, *thread_work_sources,
+            &task_from_blocking_queue, &tws);
+        if (!t.f) {
+          // Search from all requests if the thread cannot find tasks from
+          // requests that belong to its own sub thread pool.
+          t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
+                       kMaxBlockingInflight,
+                       /*may_steal_blocking_work=*/true, *thread_work_sources,
+                       &task_from_blocking_queue, &tws);
+        }
+      } else {
+        // Non-blocking threads always search from all pending requests.
+        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
+                     kMaxBlockingInflight,
+                     /*may_steal_blocking_work=*/false, *thread_work_sources,
+                     &task_from_blocking_queue, &tws);
+      }
+    } else {
       // The mutex is not hot since its per thread and can only be held
       // by some other thread when a session run starts/finishes.
       mutex_lock l(thread_data_[thread_id].mu);
+      // TODO(chaox): Refactor the following code to share the logic with
+      // FindTask.
       for (int i = 0; i < thread_work_sources->size(); ++i) {
         tws = (*thread_work_sources)[i];
         // We want a smallish numbers of inter threads since
@@ -495,20 +679,16 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
           // Always look for any work from the "primary" work source.
           // This way when we wake up a thread for a new closure we are
           // guaranteed it can be worked on.
-          for (int j = 0; j < tws->NonBlockingWorkShardingFactor(); ++j) {
-            t = tws->PopNonBlockingTask((j + thread_id) %
-                                        tws->NonBlockingWorkShardingFactor());
-            if (t.f) {
-              task_from_blocking_queue = false;
-              break;
-            }
+          t = tws->PopNonBlockingTask(thread_id, true);
+          if (t.f) {
+            task_from_blocking_queue = false;
+            break;
           }
           if (t.f) {
             break;
           }
         } else {
-          t = tws->PopNonBlockingTask(thread_id %
-                                      tws->NonBlockingWorkShardingFactor());
+          t = tws->PopNonBlockingTask(thread_id, false);
           if (t.f) {
             task_from_blocking_queue = false;
             break;
@@ -542,12 +722,30 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
                 << (*thread_work_sources)[i]->ToString();
       }
     }
-
-    WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
+    if (use_sub_thread_pool_) {
+      WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
+    } else {
+      WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
+    }
   }
 }
 
+void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
+                                                      int sub_thread_pool_id) {
+  const int kMaxSleepMicros = 250;
+
+  // The non-blocking thread will just sleep.
+  if (!is_blocking) {
+    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
+    return;
+  }
+
+  thread_local Waiter waiter;
+  WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
+               &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
+}
+
 void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
                                        int32 max_blocking_inflight) {
   const int kMaxSleepMicros = 250;
@@ -636,16 +834,33 @@ class RunHandlerPool::Impl {
   explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
       : max_handlers_(static_cast<int>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
+        waiters_mu_(
+            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
+        queue_waiters_(
+            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
         run_handler_thread_pool_(new RunHandlerThreadPool(
             num_inter_op_threads, num_intra_op_threads, Env::Default(),
-            ThreadOptions(), "tf_run_handler_pool")),
+            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
+            &queue_waiters_)),
         iterations_(0),
-        version_(0) {
+        version_(0),
+        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+            std::vector<double>({1}))) {
     VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
     for (int i = 0; i < max_handlers_; ++i) {
       handlers_.emplace_back(new RunHandler::Impl(this));
       free_handlers_.push_back(handlers_.back().get());
     }
+    queue_waiters_.resize(
+        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
+    waiters_mu_.resize(
+        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
+    for (auto& queue_waiter : queue_waiters_) {
+      queue_waiter.next = &queue_waiter;
+      queue_waiter.prev = &queue_waiter;
+    }
+    run_handler_thread_pool_->Start();
   }
 
   ~Impl() {
@@ -693,6 +908,19 @@ class RunHandlerPool::Impl {
     for (int i = 0; i < num_active_requests; ++i) {
       (*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
       (*thread_work_sources)[i]->SetRank(i);
+      int sub_thread_pool_id =
+          sub_thread_pool_end_request_percentage_.size() - 1;
+      for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
+           ++j) {
+        if (i < num_active_requests *
+                    sub_thread_pool_end_request_percentage_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      (*thread_work_sources)[i]->SetWaiter(
+          &queue_waiters_[sub_thread_pool_id],
+          &waiters_mu_[sub_thread_pool_id]);
     }
     version = ++version_;
   }
@@ -738,6 +966,19 @@ class RunHandlerPool::Impl {
     for (int i = 0; i < num_active_requests; ++i) {
       (*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
       (*thread_work_sources)[i]->SetRank(i);
+      int sub_thread_pool_id =
+          sub_thread_pool_end_request_percentage_.size() - 1;
+      for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
+           ++j) {
+        if (i < num_active_requests *
+                    sub_thread_pool_end_request_percentage_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      (*thread_work_sources)[i]->SetWaiter(
+          &queue_waiters_[sub_thread_pool_id],
+          &waiters_mu_[sub_thread_pool_id]);
     }
     version = ++version_;
     LogInfo();
@@ -759,6 +1000,9 @@ class RunHandlerPool::Impl {
   // inference).
   const int max_handlers_;
 
+  Eigen::MaxSizeVector<mutex> waiters_mu_;
+  Eigen::MaxSizeVector<Waiter> queue_waiters_;
+
   std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
   // Thread compatible part used only by lock under RunHandlerPool.
   // Handlers are sorted by start time.
@@ -773,6 +1017,7 @@ class RunHandlerPool::Impl {
   condition_variable one_handler_free_;
   mutex mu_;
   int64 version_ GUARDED_BY(mu_);
+  const std::vector<double> sub_thread_pool_end_request_percentage_;
 };
 
 void RunHandlerPool::Impl::RecomputePoolStats(
diff --git a/tensorflow/core/framework/run_handler_test.cc b/tensorflow/core/framework/run_handler_test.cc
index 263ef16796f..71b1fbc8d8d 100644
--- a/tensorflow/core/framework/run_handler_test.cc
+++ b/tensorflow/core/framework/run_handler_test.cc
@@ -24,11 +24,17 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/synchronization/barrier.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 namespace {
@@ -72,5 +78,132 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
   counter.Wait();
 }
 
+SessionOptions DefaultSessionOptions() {
+  SessionOptions options;
+  (*options.config.mutable_device_count())["CPU"] = 2;
+  return options;
+}
+
+std::unique_ptr<Session> CreateSession() {
+  return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
+}
+
+class RunHandlerTest : public ::testing::Test {
+ public:
+  void Initialize(std::initializer_list<float> a_values) {
+    Graph graph(OpRegistry::Global());
+
+    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+    test::FillValues<float>(&a_tensor, a_values);
+    Node* a = test::graph::Constant(&graph, a_tensor);
+    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+    a_ = a->name();
+
+    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+    test::FillValues<float>(&x_tensor, {1, 1});
+    Node* x = test::graph::Constant(&graph, x_tensor);
+    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+    x_ = x->name();
+
+    // y = A * x
+    Node* y = test::graph::Matmul(&graph, a, x, false, false);
+    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+    y_ = y->name();
+
+    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
+    y_neg_ = y_neg->name();
+    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
+    z_ = z->name();
+    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+    graph.ToGraphDef(&def_);
+
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
+    ASSERT_EQ(
+        setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
+        0);
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
+                     "0,0.4", true),
+              0);
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+                     "0.4,1", true),
+              0);
+    ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
+  }
+
+  string a_;
+  string x_;
+  string y_;
+  string y_neg_;
+  string z_;
+  GraphDef def_;
+};
+
+TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
+  Initialize({3, 2, -1, 0});
+  auto session = CreateSession();
+  ASSERT_TRUE(session != nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
+  std::vector<std::pair<string, Tensor>> inputs;
+
+  // Request two targets: one fetch output and one non-fetched output.
+  std::vector<string> output_names = {y_ + ":0"};
+  std::vector<string> target_nodes = {y_neg_};
+  std::vector<Tensor> outputs;
+
+  // Prepare RunOptions and RunMetadata.
+  RunOptions run_options;
+  run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+  Status s = session->Run(run_options, inputs, output_names, target_nodes,
+                          &outputs, nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), s);
+
+  ASSERT_EQ(1, outputs.size());
+  // The first output should be initialized and have the correct
+  // output.
+  auto mat = outputs[0].matrix<float>();
+  ASSERT_TRUE(outputs[0].IsInitialized());
+  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
+TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
+  Initialize({1, 2, 3, 4});
+  auto session = CreateSession();
+  ASSERT_TRUE(session != nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
+
+  RunOptions run_options;
+  run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+  // Fill in the input and ask for the output.
+  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+
+  // Run the graph 1000 times in 4 different threads concurrently.
+  std::vector<string> output_names = {y_ + ":0"};
+  auto fn = [&session, output_names, run_options]() {
+    for (int i = 0; i < 1000; ++i) {
+      std::vector<std::pair<string, Tensor>> inputs;
+      std::vector<Tensor> outputs;
+      // Run the graph.
+      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
+                              nullptr);
+      EXPECT_EQ(::tensorflow::Status::OK(), s);
+      ASSERT_EQ(1, outputs.size());
+      auto mat = outputs[0].matrix<float>();
+      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
+    }
+  };
+
+  for (int i = 0; i < 4; ++i) {
+    tp->Schedule(fn);
+  }
+
+  // Wait for the functions to finish.
+  delete tp;
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
index ebdc670a925..c832a643385 100644
--- a/tensorflow/core/framework/run_handler_util.cc
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/str_util.h"
 
 namespace tensorflow {
 
@@ -29,6 +30,54 @@ double ParamFromEnvWithDefault(const std::string& var_name,
   return (val && strings::safe_strtod(val, &num)) ? num : default_value;
 }
 
+std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
+                                            std::vector<double> default_value) {
+  const char* val = std::getenv(var_name.c_str());
+  if (!val) {
+    return default_value;
+  }
+  std::vector<string> splits = str_util::Split(val, ",");
+  std::vector<double> result;
+  result.reserve(splits.size());
+  for (auto& split : splits) {
+    double num;
+    if (strings::safe_strtod(split, &num)) {
+      result.push_back(num);
+    } else {
+      LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
+      return default_value;
+    }
+  }
+  return result;
+}
+
+std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
+                                         std::vector<int> default_value) {
+  const char* val = std::getenv(var_name.c_str());
+  if (!val) {
+    return default_value;
+  }
+  std::vector<string> splits = str_util::Split(val, ",");
+  std::vector<int> result;
+  result.reserve(splits.size());
+  for (auto& split : splits) {
+    int num;
+    if (strings::safe_strto32(split, &num)) {
+      result.push_back(num);
+    } else {
+      LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
Use default value."; + return default_value; + } + } + return result; +} + +bool ParamFromEnvBoolWithDefault(const std::string& var_name, + bool default_value) { + const char* val = std::getenv(var_name.c_str()); + return (val) ? str_util::Lowercase(val) == "true" : default_value; +} + void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, int min_threads_per_request, std::vector* start_vec, diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h index 864e6e698fc..982f06fb7e0 100644 --- a/tensorflow/core/framework/run_handler_util.h +++ b/tensorflow/core/framework/run_handler_util.h @@ -54,10 +54,27 @@ void ComputeInterOpStealingRanges(int num_threads, int min_threads_per_domain, std::vector ChooseRequestsWithExponentialDistribution( int num_active_requests, int num_threads); -// Loop environment variable named 'var_name' and return the value if it exist -// and can be parsed. Return 'default_value' otherwise. +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. double ParamFromEnvWithDefault(const std::string& var_name, double default_value); +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const std::string& var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const std::string& var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. +bool ParamFromEnvBoolWithDefault(const std::string& var_name, + bool default_value); + } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc index 7f85118671c..769991920d1 100644 --- a/tensorflow/core/framework/run_handler_util_test.cc +++ b/tensorflow/core/framework/run_handler_util_test.cc @@ -124,5 +124,44 @@ TEST(RunHandlerUtilTest, TestExponentialRequestDistribution) { ASSERT_EQ(actual_distribution, expected_distribution); } +TEST(RunHandlerUtilTest, TestParamFromEnvWithDefault) { + std::vector result = ParamFromEnvWithDefault( + "RUN_HANDLER_TEST_ENV", std::vector{0, 0, 0}); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], 0); + EXPECT_EQ(result[1], 0); + EXPECT_EQ(result[2], 0); + + std::vector result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV", + std::vector{0, 0, 0}); + EXPECT_EQ(result2.size(), 3); + EXPECT_EQ(result2[0], 0); + EXPECT_EQ(result2[1], 0); + EXPECT_EQ(result2[2], 0); + + bool result3 = + ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false); + EXPECT_EQ(result3, false); + + // Set environment variable. 
+  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV", "1,2,3", true), 0);
+  result = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
+                                   std::vector<double>{0, 0, 0});
+  EXPECT_EQ(result.size(), 3);
+  EXPECT_EQ(result[0], 1);
+  EXPECT_EQ(result[1], 2);
+  EXPECT_EQ(result[2], 3);
+  result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
+                                    std::vector<int>{0, 0, 0});
+  EXPECT_EQ(result2.size(), 3);
+  EXPECT_EQ(result2[0], 1);
+  EXPECT_EQ(result2[1], 2);
+  EXPECT_EQ(result2[2], 3);
+
+  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV_BOOL", "true", true), 0);
+  result3 = ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
+  EXPECT_EQ(result3, true);
+}
+
 }  // namespace
 }  // namespace tensorflow
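
The LIFO waiter queue that the patch factors out into WaitOnWaiter() is self-contained enough to demonstrate outside TensorFlow. Below is a minimal sketch of the same technique, with std::mutex and std::condition_variable standing in for TensorFlow's mutex and condition_variable wrappers; the WakeOne helper and the main() driver are illustrative additions, not part of the patch.

#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

// Same layout as the patch's Waiter: a node in an intrusive doubly-linked
// list. A node whose next/prev point at itself is not queued.
struct Waiter {
  Waiter() { next = prev = this; }
  std::condition_variable cv;
  std::mutex mu;
  Waiter* next;
  Waiter* prev;
};

// Push the waiter at the head (LIFO) and block until notified or timed out.
// Mirrors WaitOnWaiter() in the patch, including the re-check on exit: a
// timed-out waiter may or may not still be linked in, so unlink defensively.
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, std::mutex* queue_mu,
                  int max_sleep_micros) {
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    std::unique_lock<std::mutex> l(waiter->mu);
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
  }
  std::lock_guard<std::mutex> l(*queue_mu);
  if (waiter->next != waiter) {  // Still queued; unlink.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  }
}

// Wake the most recently queued waiter, mirroring the unlink-then-notify
// wake-up path in ThreadWorkSource.
void WakeOne(Waiter* queue_head, std::mutex* queue_mu) {
  Waiter* w = nullptr;
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    if (queue_head->next != queue_head) {
      w = queue_head->next;
      w->next->prev = w->prev;
      w->prev->next = w->next;
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    std::lock_guard<std::mutex> l(w->mu);
    w->cv.notify_one();
  }
}

int main() {
  Waiter queue_head;  // Sentinel node; an empty queue points at itself.
  std::mutex queue_mu;
  std::thread worker([&] {
    Waiter me;  // The patch keeps one thread_local Waiter per thread.
    WaitOnWaiter(&me, &queue_head, &queue_mu, /*max_sleep_micros=*/250000);
    std::puts("worker resumed");
  });
  // May race with the timeout, exactly as in the run handler; both paths
  // leave the queue in a consistent state.
  WakeOne(&queue_head, &queue_mu);
  worker.join();
  return 0;
}

Queuing waiters in LIFO order means the most recently idle thread is woken first, which is the cache-friendliness argument made in the comment above WaitOnWaiter().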
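WorkerLoop() maps each sub thread pool to the slice of active requests it searches first via the start/end request percentages. Below is a small sketch of that arithmetic under the default percentages from the RunHandlerThreadPool constructor ({0, 0.4} and {0.4, 1}); SubPoolSearchRange is a hypothetical helper, and the truncation toward zero matches the implicit double-to-int conversion at the FindTask() call sites.

#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the range arithmetic in WorkerLoop()'s
// FindTask() calls: a sub thread pool first searches active requests in
// [active_requests * start_pct, active_requests * end_pct), truncated
// toward zero by FindTask()'s int parameters.
std::pair<int, int> SubPoolSearchRange(int active_requests,
                                       int sub_thread_pool_id,
                                       const std::vector<double>& start_pct,
                                       const std::vector<double>& end_pct) {
  return {static_cast<int>(active_requests * start_pct[sub_thread_pool_id]),
          static_cast<int>(active_requests * end_pct[sub_thread_pool_id])};
}

int main() {
  // Default percentages from the RunHandlerThreadPool constructor.
  const std::vector<double> start = {0, 0.4};
  const std::vector<double> end = {0.4, 1};
  for (int pool = 0; pool < 2; ++pool) {
    std::pair<int, int> range = SubPoolSearchRange(10, pool, start, end);
    // Prints: sub pool 0 -> [0, 4), sub pool 1 -> [4, 10).
    std::printf("sub pool %d searches requests [%d, %d)\n", pool, range.first,
                range.second);
  }
  return 0;
}

With 10 active requests, sub pool 0 first searches requests [0, 4) (the oldest, highest-priority requests) and sub pool 1 searches [4, 10); blocking threads fall back to the full range [0, 10) only when their own slice yields no task.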
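The new vector-valued ParamFromEnvWithDefault() overloads accept comma-separated values, e.g. TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE=0.4,1. The following is a simplified stand-in showing the parse-or-fall-back behavior with only the standard library; DoublesFromEnv is a hypothetical name, and the patch itself uses str_util::Split and strings::safe_strtod rather than exceptions.

#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the std::vector<double> overload of
// ParamFromEnvWithDefault(): parse a comma-separated environment variable,
// falling back to 'default_value' when the variable is unset or any element
// fails to parse (the patch logs an error in that case).
std::vector<double> DoublesFromEnv(const char* var_name,
                                   std::vector<double> default_value) {
  const char* val = std::getenv(var_name);
  if (val == nullptr) return default_value;
  std::vector<double> result;
  std::stringstream ss(val);
  std::string token;
  while (std::getline(ss, token, ',')) {
    try {
      size_t pos = 0;
      double num = std::stod(token, &pos);
      if (pos != token.size()) return default_value;  // Trailing junk.
      result.push_back(num);
    } catch (...) {  // std::invalid_argument or std::out_of_range.
      return default_value;
    }
  }
  return result;
}

// Example: with TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE set
// to "0.4,1" this yields {0.4, 1.0}; when the variable is unset, the
// caller-supplied default (the patch uses {1}) is returned unchanged.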