Introduce sub thread pools to run handler thread pool.

PiperOrigin-RevId: 279237293
Change-Id: I8f1945105d27d1ed06eee8bb914bfaa06fec6c1f
This commit is contained in:
Chao Xie 2019-11-07 21:34:49 -08:00 committed by TensorFlower Gardener
parent af019188ad
commit bfce5bae88
6 changed files with 589 additions and 99 deletions

View File

@ -4468,11 +4468,18 @@ tf_cc_test(
srcs = ["framework/run_handler_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core_cpu",
":direct_session_internal",
":framework_internal",
":lib",
":lib_internal",
":protos_all_cc",
":tensor_testutil",
":test",
":test_main",
":testlib",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:matmul_op",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",

View File

@ -98,6 +98,55 @@ class RunHandlerEnvironment {
typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;
// Node of a circular doubly-linked list of waiting threads. Waiters are
// queued in LIFO order (instead of the FIFO order a single condition variable
// would give) to reduce cache misses.
//
// A freshly constructed Waiter links to itself, which is how the rest of the
// code recognizes a waiter that is not currently in any queue.
struct Waiter {
  Waiter() : next(this), prev(this) {}

  condition_variable cv;  // Signalled to wake the thread owning this waiter.
  mutex mu;               // Held while waiting on `cv`.
  Waiter* next;           // Successor in the circular list.
  Waiter* prev;           // Predecessor in the circular list.
};
// Parks the calling thread on `waiter` for at most `max_sleep_micros`
// microseconds.
//
// The function runs in three phases:
//   1. Under `*mutex`, link `waiter` right after `queue_head` (LIFO order).
//   2. Under `waiter->mu`, wait on `waiter->cv` with a timeout.
//   3. Under `*mutex` again, unlink `waiter` if it is still in the queue.
//
// `waiter` must be self-linked (i.e. not in any queue) on entry; this is
// enforced with CHECKs. On return `waiter` is self-linked again.
// `*mutex` is the lock that protects the list headed by `queue_head`.
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
int max_sleep_micros) {
{
mutex_lock l(*mutex);
CHECK_EQ(waiter->next, waiter); // Crash OK.
CHECK_EQ(waiter->prev, waiter); // Crash OK.
// Add waiter to the LIFO queue
waiter->prev = queue_head;
waiter->next = queue_head->next;
waiter->next->prev = waiter;
waiter->prev->next = waiter;
}
{
// The queue lock was released above, so it is not held while sleeping.
mutex_lock l(waiter->mu);
// Wait on the condition variable
waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
}
mutex_lock l(*mutex);
// Remove waiter from the LIFO queue. Note even when a waiter wakes up due
// to a notification we cannot conclude the waiter is not in the queue.
// This is due to the fact that a thread preempted right before notifying
// may resume after a waiter got re-added.
if (waiter->next != waiter) {
CHECK(waiter->prev != waiter); // Crash OK.
waiter->next->prev = waiter->prev;
waiter->prev->next = waiter->next;
waiter->next = waiter;
waiter->prev = waiter;
} else {
// Already unlinked; both links must agree that the waiter is self-linked.
CHECK_EQ(waiter->prev, waiter); // Crash OK.
}
}
class ThreadWorkSource {
public:
ThreadWorkSource()
@ -155,11 +204,32 @@ class ThreadWorkSource {
if (max_rank_to_wakeup > 0 &&
rank_.load(std::memory_order_relaxed) <= max_rank_to_wakeup) {
Waiter* w = nullptr;
bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
Waiter* waiter_queue;
mutex* waiter_queue_mu;
if (use_sub_thread_pool) {
// When we use multiple sub thread pools, free threads wait on sub
// thread pool waiting queues. Wake up threads from sub thread waiting
// queues.
// The waiting queues are defined at RunHandlerPool.
// Get the waiter_queue and the corresponding mutex. Note that the thread
// work source may change afterwards if a new request comes in or an old
// request finishes.
tf_shared_lock lock(run_handler_waiter_mu_);
waiter_queue = sub_thread_pool_waiter_;
waiter_queue_mu = sub_thread_pool_waiter_mu_;
} else {
waiter_queue = &queue_waiters_;
waiter_queue_mu = &waiters_mu_;
}
{
mutex_lock l(waiters_mu_);
if (queue_waiters_.next != &queue_waiters_) {
mutex_lock l(*waiter_queue_mu);
if (waiter_queue->next != waiter_queue) {
// Remove waiter from the LIFO queue
w = queue_waiters_.next;
w = waiter_queue->next;
CHECK(w->prev != w);
CHECK(w->next != w);
@ -187,43 +257,25 @@ class ThreadWorkSource {
Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
Task PopNonBlockingTask(int index) {
return non_blocking_work_queues_[index]->queue.PopBack();
Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
Task t;
unsigned sharding_factor = NonBlockingWorkShardingFactor();
for (unsigned j = 0; j < sharding_factor; ++j) {
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
->queue.PopBack();
if (t.f) {
return t;
}
if (!search_from_all_queue) {
break;
}
}
return t;
}
void WaitForWork(int max_sleep_micros) {
thread_local Waiter waiter;
{
mutex_lock l(waiters_mu_);
CHECK_EQ(waiter.next, &waiter);
CHECK_EQ(waiter.prev, &waiter);
// Add waiter to the LIFO queue
waiter.prev = &queue_waiters_;
waiter.next = queue_waiters_.next;
waiter.next->prev = &waiter;
waiter.prev->next = &waiter;
}
{
mutex_lock l(waiter.mu);
// Wait on the condition variable
waiter.cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
}
mutex_lock l(waiters_mu_);
// Remove waiter from the LIFO queue. Note even when a waiter wakes up due
// to a notification we cannot conclude the waiter is not in the queue.
// This is due to the fact that a thread preempted right before notifying
// may resume after a waiter got re-added.
if (waiter.next != &waiter) {
CHECK(waiter.prev != &waiter);
waiter.next->prev = waiter.prev;
waiter.prev->next = waiter.next;
waiter.next = &waiter;
waiter.prev = &waiter;
} else {
CHECK_EQ(waiter.prev, &waiter);
}
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}
int TaskQueueSize(bool is_blocking) {
@ -243,6 +295,12 @@ class ThreadWorkSource {
void SetTracemeId(int64 value) { traceme_id_ = value; }
void SetRank(int64 value) { rank_ = value; }
// Records the waiter queue head and the mutex guarding it that this work
// source should use. Both pointers are stored as-is (no ownership is taken)
// and are updated together under run_handler_waiter_mu_.
void SetWaiter(Waiter* waiter, mutex* mutex) {
mutex_lock l(run_handler_waiter_mu_);
sub_thread_pool_waiter_ = waiter;
sub_thread_pool_waiter_mu_ = mutex;
}
int64 GetInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
@ -274,20 +332,6 @@ class ThreadWorkSource {
}
private:
// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
Waiter() {
next = this;
prev = this;
}
condition_variable cv;
mutex mu;
Waiter* next;
Waiter* prev;
};
struct NonBlockingQueue {
mutex queue_op_mu;
char pad[128];
@ -307,6 +351,10 @@ class ThreadWorkSource {
Waiter queue_waiters_ GUARDED_BY(waiters_mu_);
std::atomic<int64> traceme_id_;
std::atomic<int64> rank_;
mutex run_handler_waiter_mu_;
mutex* sub_thread_pool_waiter_mu_ GUARDED_BY(run_handler_waiter_mu_);
Waiter* sub_thread_pool_waiter_ GUARDED_BY(run_handler_waiter_mu_);
};
class RunHandlerThreadPool {
@ -319,25 +367,33 @@ class RunHandlerThreadPool {
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
Env* env, const ThreadOptions& thread_options,
const string& name)
const string& name,
Eigen::MaxSizeVector<mutex>* waiters_mu,
Eigen::MaxSizeVector<Waiter>* queue_waiters)
: num_threads_(num_blocking_threads + num_non_blocking_threads),
num_blocking_threads_(num_blocking_threads),
num_non_blocking_threads_(num_non_blocking_threads),
thread_data_(num_threads_),
env_(env, thread_options, name),
name_(name) {
name_(name),
waiters_mu_(waiters_mu),
queue_waiters_(queue_waiters),
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
std::vector<int>(
{num_blocking_threads / 2,
num_blocking_threads - num_blocking_threads / 2}))),
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
std::vector<double>({0, 0.4}))),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({0.4, 1}))) {
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
<< num_blocking_threads_ << " blocking threads and "
<< num_non_blocking_threads_ << " non-blocking threads.";
cancelled_ = false;
thread_data_.resize(num_threads_);
for (int i = 0; i < num_threads_; i++) {
thread_data_[i].thread.reset(
env_.CreateThread([this, i, num_blocking_threads]() {
WorkerLoop(i, i < num_blocking_threads);
}));
}
}
~RunHandlerThreadPool() {
@ -353,6 +409,26 @@ class RunHandlerThreadPool {
}
}
// Assigns every worker thread to a sub thread pool and launches it.
//
// num_threads_in_sub_thread_pool_[j] is the number of threads in sub thread
// pool j, so thread i belongs to the first pool whose cumulative thread count
// exceeds i. Threads beyond the configured total fall into the last pool.
// The first num_blocking_threads_ threads are started as blocking threads.
void Start() {
  cancelled_ = false;
  thread_data_.resize(num_threads_);
  const int num_blocking_threads = num_blocking_threads_;
  const int num_sub_thread_pools =
      static_cast<int>(num_threads_in_sub_thread_pool_.size());
  for (int i = 0; i < num_threads_; i++) {
    // Default to the last sub thread pool.
    int sub_thread_pool_id = num_sub_thread_pools - 1;
    // Compare against the running total rather than the size of pool j
    // alone; otherwise pools other than the first and the last would never
    // receive any thread.
    int cumulative_threads = 0;
    for (int j = 0; j < num_sub_thread_pools; ++j) {
      cumulative_threads += num_threads_in_sub_thread_pool_[j];
      if (i < cumulative_threads) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    thread_data_[i].thread.reset(
        env_.CreateThread([this, i, num_blocking_threads]() {
          WorkerLoop(i, i < num_blocking_threads);
        }));
  }
}
void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
std::function<void()> fn) {
Task t = env_.CreateTask(std::move(fn));
@ -384,30 +460,37 @@ class RunHandlerThreadPool {
return;
}
thread_data_[tid].thread_work_sources.resize(0);
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[start_request_idx]);
// The number of shards for the queue. Threads in each shard will prioritize
// different thread_work_sources. Increase the number of shards could
// decrease the contention in the queue.
// For example, when num_shards == 1:
// thread_work_sources are ordered as start_request_idx, 0, 1, 2, 3, 4 ...
// for all threads.
// When num_shards == 2:
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
// 4... for the other half of the threads.
int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
int token = tid % num_shards;
for (int i = 0; i < num_shards; ++i) {
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
if (j != start_request_idx) {
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[j]);
}
if (use_sub_thread_pool_) {
for (int i = 0; i < thread_work_sources.size(); ++i) {
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[i]);
}
token = (token + 1) % num_shards;
} else {
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[start_request_idx]);
// The number of shards for the queue. Threads in each shard will
// prioritize different thread_work_sources. Increase the number of shards
// could decrease the contention in the queue. For example, when
// num_shards == 1: thread_work_sources are ordered as start_request_idx,
// 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
// thread_work_sources are order as start_request_idx, 0, 2, 4 ... 1, 3,
// 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
// 4... for the other half of the threads.
int num_shards =
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
int token = tid % num_shards;
for (int i = 0; i < num_shards; ++i) {
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
if (j != start_request_idx) {
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[j]);
}
}
token = (token + 1) % num_shards;
}
thread_data_[tid].sources_not_empty.notify_all();
}
thread_data_[tid].sources_not_empty.notify_all();
}
PerThread* GetPerThread() {
@ -434,13 +517,26 @@ class RunHandlerThreadPool {
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
// Search tasks from requests in the range searching_range_start to
// searching_range_end. If there are no tasks in the search range and
// may_steal_blocking_work is true, then search from all requests.
Task FindTask(
int searching_range_start, int searching_range_end, int thread_id,
int sub_thread_pool_id, int max_blocking_inflight,
bool may_steal_blocking_work,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
bool* task_from_blocking_queue, ThreadWorkSource** tws);
void WaitForWork(bool is_blocking, int thread_id,
int32 max_blocking_inflight);
void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
private:
struct ThreadData {
ThreadData()
: version(0),
current_index(0),
thread_work_sources(static_cast<int32>(
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers))) {}
@ -448,7 +544,9 @@ class RunHandlerThreadPool {
uint64 version;
condition_variable sources_not_empty;
std::unique_ptr<Thread> thread;
int current_index;
Eigen::MaxSizeVector<ThreadWorkSource*> thread_work_sources GUARDED_BY(mu);
int sub_thread_pool_id;
};
const int num_threads_;
@ -458,8 +556,58 @@ class RunHandlerThreadPool {
RunHandlerEnvironment env_;
std::atomic<bool> cancelled_;
string name_;
Eigen::MaxSizeVector<mutex>* waiters_mu_;
Eigen::MaxSizeVector<Waiter>* queue_waiters_;
bool use_sub_thread_pool_;
std::vector<int> num_threads_in_sub_thread_pool_;
// Threads in each sub thread pool will search tasks from the given
// start_request_percentage to end_request_percentage in a round robin
// fashion.
std::vector<double> sub_thread_pool_start_request_percentage_;
std::vector<double> sub_thread_pool_end_request_percentage_;
};
// Looks for a runnable task among the requests with index in
// [searching_range_start, searching_range_end) of `thread_work_sources`.
//
// The search is round robin: it resumes from the per-thread cursor
// thread_data_[thread_id].current_index and visits each request in the range
// at most once. For every visited request, a blocking task is tried first
// when `may_steal_blocking_work` is true and the request has fewer than
// `max_blocking_inflight` blocking tasks in flight; otherwise (or if none is
// available) a non-blocking task is tried.
//
// Outputs:
//   *task_from_blocking_queue - true iff the returned task came from a
//                               blocking queue.
//   *tws                      - the last work source that was inspected (the
//                               owner of the returned task when one is found).
// Returns the task found, or an empty task (t.f is null) if there is none.
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  // TODO(chaox): Change the search algorithm from round robin to random
  // walk.
  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    // The cursor may have been left outside of this range by a previous
    // search over a different range (or a different number of active
    // requests). Reset it on either side so that only requests inside the
    // requested range are visited.
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For blocking thread, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  // Remember where to resume the next search for this thread.
  thread_data_[thread_id].current_index = current_index;
  return t;
}
// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
bool may_steal_blocking_work) {
@ -474,11 +622,47 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
bool task_from_blocking_queue = true;
Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
&thread_data_[thread_id].thread_work_sources;
{
int sub_thread_pool_id;
if (use_sub_thread_pool_) {
// The mutex is not hot since it's per thread and can only be held
// by some other thread when a session run starts/finishes.
mutex_lock l(thread_data_[thread_id].mu);
sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
int active_requests = thread_work_sources->size();
if (may_steal_blocking_work) {
// Each thread will first look for tasks from requests that belongs to
// its sub thread pool.
t = FindTask(
active_requests *
sub_thread_pool_start_request_percentage_[sub_thread_pool_id],
active_requests *
sub_thread_pool_end_request_percentage_[sub_thread_pool_id],
thread_id, sub_thread_pool_id, kMaxBlockingInflight,
/*may_steal_blocking_work=*/true, *thread_work_sources,
&task_from_blocking_queue, &tws);
if (!t.f) {
// Search from all requests if the thread cannot find tasks from
// requests that belong to its own sub thread pool.
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
kMaxBlockingInflight,
/*may_steal_blocking_work=*/true, *thread_work_sources,
&task_from_blocking_queue, &tws);
}
} else {
// For non-blocking threads, it will always search from all pending
// requests.
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
kMaxBlockingInflight,
/*may_steal_blocking_work=*/false, *thread_work_sources,
&task_from_blocking_queue, &tws);
}
} else {
// The mutex is not hot since it's per thread and can only be held
// by some other thread when a session run starts/finishes.
mutex_lock l(thread_data_[thread_id].mu);
// TODO(chaox): Refactor the following code to share the logic with
// FindTask.
for (int i = 0; i < thread_work_sources->size(); ++i) {
tws = (*thread_work_sources)[i];
// We want a smallish numbers of inter threads since
@ -495,20 +679,16 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
// Always look for any work from the "primary" work source.
// This way when we wake up a thread for a new closure we are
// guaranteed it can be worked on.
for (int j = 0; j < tws->NonBlockingWorkShardingFactor(); ++j) {
t = tws->PopNonBlockingTask((j + thread_id) %
tws->NonBlockingWorkShardingFactor());
if (t.f) {
task_from_blocking_queue = false;
break;
}
t = tws->PopNonBlockingTask(thread_id, true);
if (t.f) {
task_from_blocking_queue = false;
break;
}
if (t.f) {
break;
}
} else {
t = tws->PopNonBlockingTask(thread_id %
tws->NonBlockingWorkShardingFactor());
t = tws->PopNonBlockingTask(thread_id, false);
if (t.f) {
task_from_blocking_queue = false;
break;
@ -542,12 +722,30 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
<< (*thread_work_sources)[i]->ToString();
}
}
WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
if (use_sub_thread_pool_) {
WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
} else {
WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
}
}
}
}
void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
int sub_thread_pool_id) {
const int kMaxSleepMicros = 250;
// The non-blocking thread will just sleep.
if (!is_blocking) {
Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
return;
}
thread_local Waiter waiter;
WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
&(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
}
void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
int32 max_blocking_inflight) {
const int kMaxSleepMicros = 250;
@ -636,16 +834,33 @@ class RunHandlerPool::Impl {
explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
: max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
waiters_mu_(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
queue_waiters_(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
run_handler_thread_pool_(new RunHandlerThreadPool(
num_inter_op_threads, num_intra_op_threads, Env::Default(),
ThreadOptions(), "tf_run_handler_pool")),
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
&queue_waiters_)),
iterations_(0),
version_(0) {
version_(0),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({1}))) {
VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
for (int i = 0; i < max_handlers_; ++i) {
handlers_.emplace_back(new RunHandler::Impl(this));
free_handlers_.push_back(handlers_.back().get());
}
queue_waiters_.resize(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
waiters_mu_.resize(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
for (auto& queue_waiter : queue_waiters_) {
queue_waiter.next = &queue_waiter;
queue_waiter.prev = &queue_waiter;
}
run_handler_thread_pool_->Start();
}
~Impl() {
@ -693,6 +908,19 @@ class RunHandlerPool::Impl {
for (int i = 0; i < num_active_requests; ++i) {
(*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
(*thread_work_sources)[i]->SetRank(i);
int sub_thread_pool_id =
sub_thread_pool_end_request_percentage_.size() - 1;
for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
++j) {
if (i < num_active_requests *
sub_thread_pool_end_request_percentage_[j]) {
sub_thread_pool_id = j;
break;
}
}
(*thread_work_sources)[i]->SetWaiter(
&queue_waiters_[sub_thread_pool_id],
&waiters_mu_[sub_thread_pool_id]);
}
version = ++version_;
}
@ -738,6 +966,19 @@ class RunHandlerPool::Impl {
for (int i = 0; i < num_active_requests; ++i) {
(*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
(*thread_work_sources)[i]->SetRank(i);
int sub_thread_pool_id =
sub_thread_pool_end_request_percentage_.size() - 1;
for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
++j) {
if (i < num_active_requests *
sub_thread_pool_end_request_percentage_[j]) {
sub_thread_pool_id = j;
break;
}
}
(*thread_work_sources)[i]->SetWaiter(
&queue_waiters_[sub_thread_pool_id],
&waiters_mu_[sub_thread_pool_id]);
}
version = ++version_;
LogInfo();
@ -759,6 +1000,9 @@ class RunHandlerPool::Impl {
// inference).
const int max_handlers_;
Eigen::MaxSizeVector<mutex> waiters_mu_;
Eigen::MaxSizeVector<Waiter> queue_waiters_;
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
// Thread compatible part used only by lock under RunHandlerPool.
// Handlers are sorted by start time.
@ -773,6 +1017,7 @@ class RunHandlerPool::Impl {
condition_variable one_handler_free_;
mutex mu_;
int64 version_ GUARDED_BY(mu_);
const std::vector<double> sub_thread_pool_end_request_percentage_;
};
void RunHandlerPool::Impl::RecomputePoolStats(

View File

@ -24,11 +24,17 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/synchronization/barrier.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
@ -72,5 +78,132 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
counter.Wait();
}
// Builds the session options used by the tests below: default options with
// the "CPU" device count set to 2.
SessionOptions DefaultSessionOptions() {
  SessionOptions session_options;
  auto& device_count = *session_options.config.mutable_device_count();
  device_count["CPU"] = 2;
  return session_options;
}
std::unique_ptr<Session> CreateSession() {
return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
}
// Test fixture that builds a small graph and configures the run handler
// environment variables before a session is created from the graph.
class RunHandlerTest : public ::testing::Test {
public:
// Builds the graph
//   y = A * x, y_neg = Neg(y), z = Identity(y_neg)
// where A is a 2x2 constant filled with `a_values` and x is a 2x1 constant
// of ones. Node names are stored in a_, x_, y_, y_neg_ and z_, and the
// resulting GraphDef in def_. Afterwards the run handler environment
// variables are set.
void Initialize(std::initializer_list<float> a_values) {
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&a_tensor, a_values);
Node* a = test::graph::Constant(&graph, a_tensor);
a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
a_ = a->name();
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {1, 1});
Node* x = test::graph::Constant(&graph, x_tensor);
x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
x_ = x->name();
// y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
y_ = y->name();
Node* y_neg = test::graph::Unary(&graph, "Neg", y);
y_neg_ = y_neg->name();
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
Node* z = test::graph::Unary(&graph, "Identity", y_neg);
z_ = z->name();
z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
graph.ToGraphDef(&def_);
// Configure two sub thread pools of 8 threads each, serving the request
// ranges [0, 0.4) and [0.4, 1), with 16 inter-op threads in total.
// NOTE(review): TF_RUN_HANDLER_USE_SUB_THREAD_POOL is not set here and the
// code that reads it defaults to false -- confirm whether the sub thread
// pool path is actually exercised by these tests.
ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
ASSERT_EQ(
setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
0);
ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
"0,0.4", true),
0);
ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
"0.4,1", true),
0);
ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
}
// Names of the graph nodes created by Initialize().
string a_;
string x_;
string y_;
string y_neg_;
string z_;
// Serialized form of the graph created by Initialize().
GraphDef def_;
};
// Runs the graph once through the run handler pool and checks the fetched
// value of y = A * x for A = [[3, 2], [-1, 0]] and x = [1, 1]^T.
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output. Verify initialization before reading the tensor contents.
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
// Runs the graph concurrently from several threads through the run handler
// pool and checks the fetched value of y = A * x for A = [[1, 2], [3, 4]] and
// x = [1, 1]^T on every run.
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));

  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  // Fill in the input and ask for the output.
  // The pool is owned by a unique_ptr (rather than a raw pointer) so that it
  // cannot leak; it is declared after `session`, which the closure below
  // captures by reference.
  std::unique_ptr<thread::ThreadPool> tp(
      new thread::ThreadPool(Env::Default(), "test", 4));

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names, run_options]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(::tensorflow::Status::OK(), s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish.
  tp.reset();
}
} // namespace
} // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/str_util.h"
namespace tensorflow {
@ -29,6 +30,54 @@ double ParamFromEnvWithDefault(const std::string& var_name,
return (val && strings::safe_strtod(val, &num)) ? num : default_value;
}
std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
std::vector<double> default_value) {
const char* val = std::getenv(var_name.c_str());
if (!val) {
return default_value;
}
std::vector<string> splits = str_util::Split(val, ",");
std::vector<double> result;
result.reserve(splits.size());
for (auto& split : splits) {
double num;
if (strings::safe_strtod(split, &num)) {
result.push_back(num);
} else {
LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
return default_value;
}
}
return result;
}
std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
std::vector<int> default_value) {
const char* val = std::getenv(var_name.c_str());
if (!val) {
return default_value;
}
std::vector<string> splits = str_util::Split(val, ",");
std::vector<int> result;
result.reserve(splits.size());
for (auto& split : splits) {
int num;
if (strings::safe_strto32(split, &num)) {
result.push_back(num);
} else {
LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
return default_value;
}
}
return result;
}
// Reads the environment variable `var_name` as a boolean. Returns
// `default_value` when the variable is not set. Otherwise returns true only
// if the value, lowercased, equals "true"; any other value yields false.
bool ParamFromEnvBoolWithDefault(const std::string& var_name,
                                 bool default_value) {
  const char* env_value = std::getenv(var_name.c_str());
  if (env_value == nullptr) {
    return default_value;
  }
  return str_util::Lowercase(env_value) == "true";
}
void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
int min_threads_per_request,
std::vector<std::uint_fast32_t>* start_vec,

View File

@ -54,10 +54,27 @@ void ComputeInterOpStealingRanges(int num_threads, int min_threads_per_domain,
std::vector<int> ChooseRequestsWithExponentialDistribution(
int num_active_requests, int num_threads);
// Loop environment variable named 'var_name' and return the value if it exist
// and can be parsed. Return 'default_value' otherwise.
// Look up the environment variable named 'var_name' and return its value if
// it exists and can be parsed. Return 'default_value' otherwise.
double ParamFromEnvWithDefault(const std::string& var_name,
double default_value);
// Look up the environment variable named 'var_name' and return its value if
// it exists and can be parsed. The value must be in the format val1,val2...
// Return 'default_value' otherwise.
std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
std::vector<double> default_value);
// Look up the environment variable named 'var_name' and return its value if
// it exists and can be parsed. The value must be in the format val1,val2...
// Return 'default_value' otherwise.
std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
std::vector<int> default_value);
// Look up the environment variable named 'var_name' and return its value if
// it exists and can be parsed. Return 'default_value' otherwise.
bool ParamFromEnvBoolWithDefault(const std::string& var_name,
bool default_value);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_

View File

@ -124,5 +124,44 @@ TEST(RunHandlerUtilTest, TestExponentialRequestDistribution) {
ASSERT_EQ(actual_distribution, expected_distribution);
}
// Checks the vector<double>, vector<int> and bool environment readers: first
// with the variables unset (the defaults must come back), then with the
// variables set (the parsed values must come back).
TEST(RunHandlerUtilTest, TestParamFromEnvWithDefault) {
  std::vector<double> result = ParamFromEnvWithDefault(
      "RUN_HANDLER_TEST_ENV", std::vector<double>{0, 0, 0});
  EXPECT_EQ(result.size(), 3);
  EXPECT_EQ(result[0], 0);
  EXPECT_EQ(result[1], 0);
  EXPECT_EQ(result[2], 0);

  std::vector<int> result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                                     std::vector<int>{0, 0, 0});
  EXPECT_EQ(result2.size(), 3);
  EXPECT_EQ(result2[0], 0);
  EXPECT_EQ(result2[1], 0);
  EXPECT_EQ(result2[2], 0);

  bool result3 =
      ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
  EXPECT_EQ(result3, false);

  // Set environment variable.
  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV", "1,2,3", true), 0);
  result = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                   std::vector<double>{0, 0, 0});
  EXPECT_EQ(result.size(), 3);
  EXPECT_EQ(result[0], 1);
  EXPECT_EQ(result[1], 2);
  EXPECT_EQ(result[2], 3);

  result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
                                    std::vector<int>{0, 0, 0});
  // Check the size of the int result itself (not of the double result).
  EXPECT_EQ(result2.size(), 3);
  EXPECT_EQ(result2[0], 1);
  EXPECT_EQ(result2[1], 2);
  EXPECT_EQ(result2[2], 3);

  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV_BOOL", "true", true), 0);
  result3 = ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
  EXPECT_EQ(result3, true);
}
} // namespace
} // namespace tensorflow