Introduce sub thread pools to the run handler thread pool.

PiperOrigin-RevId: 279237293
Change-Id: I8f1945105d27d1ed06eee8bb914bfaa06fec6c1f
parent af019188ad
commit bfce5bae88
@@ -4468,11 +4468,18 @@ tf_cc_test(
     srcs = ["framework/run_handler_test.cc"],
     linkstatic = tf_kernel_tests_linkstatic(),
     deps = [
+        ":core_cpu",
+        ":direct_session_internal",
         ":framework_internal",
         ":lib",
         ":lib_internal",
+        ":protos_all_cc",
+        ":tensor_testutil",
         ":test",
         ":test_main",
+        ":testlib",
+        "//tensorflow/core/kernels:cwise_op",
+        "//tensorflow/core/kernels:matmul_op",
         "//third_party/eigen3",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/synchronization",
@@ -98,6 +98,55 @@ class RunHandlerEnvironment {
 typedef typename RunHandlerEnvironment::Task Task;
 typedef Eigen::RunQueue<Task, 1024> Queue;
 
+// To reduce cache misses, we use a doubly-linked list of Waiter structs and
+// queue them in LIFO order rather than the FIFO order used by a single
+// condition variable.
+struct Waiter {
+  Waiter() {
+    next = this;
+    prev = this;
+  }
+  condition_variable cv;
+  mutex mu;
+  Waiter* next;
+  Waiter* prev;
+};
+
+void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
+                  int max_sleep_micros) {
+  {
+    mutex_lock l(*mutex);
+    CHECK_EQ(waiter->next, waiter);  // Crash OK.
+    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
+
+    // Add waiter to the LIFO queue
+    waiter->prev = queue_head;
+    waiter->next = queue_head->next;
+    waiter->next->prev = waiter;
+    waiter->prev->next = waiter;
+  }
+  {
+    mutex_lock l(waiter->mu);
+    // Wait on the condition variable
+    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
+  }
+
+  mutex_lock l(*mutex);
+  // Remove waiter from the LIFO queue. Note even when a waiter wakes up due
+  // to a notification we cannot conclude the waiter is not in the queue.
+  // This is due to the fact that a thread preempted right before notifying
+  // may resume after a waiter got re-added.
+  if (waiter->next != waiter) {
+    CHECK(waiter->prev != waiter);  // Crash OK.
+    waiter->next->prev = waiter->prev;
+    waiter->prev->next = waiter->next;
+    waiter->next = waiter;
+    waiter->prev = waiter;
+  } else {
+    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
+  }
+}
+
 class ThreadWorkSource {
  public:
   ThreadWorkSource()
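As an aside for readers outside the TensorFlow tree, the wait/wake protocol this hunk introduces can be sketched with plain std:: primitives (tensorflow::mutex and condition_variable are thin wrappers over them). The Wakeup helper below mirrors the dequeue-and-notify path shown in the next hunk; all names in this sketch are illustrative, not part of the change.

#include <chrono>
#include <condition_variable>
#include <mutex>

struct Waiter {
  Waiter() : next(this), prev(this) {}
  std::condition_variable cv;
  std::mutex mu;
  Waiter* next;
  Waiter* prev;
};

// Push a waiter at the head (LIFO), timed-wait on its private cv, then
// unlink it, tolerating a concurrent Wakeup() that already unlinked it.
void WaitOnWaiter(Waiter* w, Waiter* head, std::mutex* queue_mu,
                  int max_sleep_micros) {
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    w->prev = head;
    w->next = head->next;
    w->next->prev = w;
    w->prev->next = w;
  }
  {
    std::unique_lock<std::mutex> l(w->mu);
    w->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
  }
  std::lock_guard<std::mutex> l(*queue_mu);
  if (w->next != w) {  // Still queued: we woke from a timeout or stale notify.
    w->next->prev = w->prev;
    w->prev->next = w->next;
    w->next = w;
    w->prev = w;
  }
}

// Wake the most recently queued waiter, if any.
void Wakeup(Waiter* head, std::mutex* queue_mu) {
  Waiter* w = nullptr;
  {
    std::lock_guard<std::mutex> l(*queue_mu);
    if (head->next != head) {
      w = head->next;  // LIFO: the last waiter to queue up is woken first.
      w->next->prev = w->prev;
      w->prev->next = w->next;
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    std::lock_guard<std::mutex> l(w->mu);
    w->cv.notify_one();
  }
}

Because each waiter has its own mutex and condition variable, a wake-up touches only one waiter's cache lines, which is the cache-miss reduction the comment above refers to.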
@@ -155,11 +204,32 @@ class ThreadWorkSource {
     if (max_rank_to_wakeup > 0 &&
         rank_.load(std::memory_order_relaxed) <= max_rank_to_wakeup) {
       Waiter* w = nullptr;
+      bool use_sub_thread_pool = ParamFromEnvBoolWithDefault(
+          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
+
+      Waiter* waiter_queue;
+      mutex* waiter_queue_mu;
+      if (use_sub_thread_pool) {
+        // When we use multiple sub thread pools, free threads wait on sub
+        // thread pool waiting queues. Wake up threads from sub thread waiting
+        // queues.
+        // The waiting queues are defined at RunHandlerPool.
+        // Get the waiter_queue and corresponding mutex. Note, the thread work
+        // source may change afterwards if a new request comes or an old
+        // request finishes.
+        tf_shared_lock lock(run_handler_waiter_mu_);
+        waiter_queue = sub_thread_pool_waiter_;
+        waiter_queue_mu = sub_thread_pool_waiter_mu_;
+      } else {
+        waiter_queue = &queue_waiters_;
+        waiter_queue_mu = &waiters_mu_;
+      }
+
       {
-        mutex_lock l(waiters_mu_);
-        if (queue_waiters_.next != &queue_waiters_) {
+        mutex_lock l(*waiter_queue_mu);
+        if (waiter_queue->next != waiter_queue) {
           // Remove waiter from the LIFO queue
-          w = queue_waiters_.next;
+          w = waiter_queue->next;
+
           CHECK(w->prev != w);
           CHECK(w->next != w);
@@ -187,43 +257,25 @@ class ThreadWorkSource {
 
   Task PopBlockingTask() { return blocking_work_queue_.PopBack(); }
 
-  Task PopNonBlockingTask(int index) {
-    return non_blocking_work_queues_[index]->queue.PopBack();
+  Task PopNonBlockingTask(int start_index, bool search_from_all_queue) {
+    Task t;
+    unsigned sharding_factor = NonBlockingWorkShardingFactor();
+    for (unsigned j = 0; j < sharding_factor; ++j) {
+      t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
+              ->queue.PopBack();
+      if (t.f) {
+        return t;
+      }
+      if (!search_from_all_queue) {
+        break;
+      }
+    }
+    return t;
   }
 
   void WaitForWork(int max_sleep_micros) {
     thread_local Waiter waiter;
-    {
-      mutex_lock l(waiters_mu_);
-      CHECK_EQ(waiter.next, &waiter);
-      CHECK_EQ(waiter.prev, &waiter);
-
-      // Add waiter to the LIFO queue
-      waiter.prev = &queue_waiters_;
-      waiter.next = queue_waiters_.next;
-      waiter.next->prev = &waiter;
-      waiter.prev->next = &waiter;
-    }
-    {
-      mutex_lock l(waiter.mu);
-      // Wait on the condition variable
-      waiter.cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
-    }
-
-    mutex_lock l(waiters_mu_);
-    // Remove waiter from the LIFO queue. Note even when a waiter wakes up due
-    // to a notification we cannot conclude the waiter is not in the queue.
-    // This is due to the fact that a thread preempted right before notifying
-    // may resume after a waiter got re-added.
-    if (waiter.next != &waiter) {
-      CHECK(waiter.prev != &waiter);
-      waiter.next->prev = waiter.prev;
-      waiter.prev->next = waiter.next;
-      waiter.next = &waiter;
-      waiter.prev = &waiter;
-    } else {
-      CHECK_EQ(waiter.prev, &waiter);
-    }
+    WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
   }
 
   int TaskQueueSize(bool is_blocking) {
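The reworked PopNonBlockingTask above scans up to every shard, starting at a caller-chosen index. The following is a minimal stand-alone sketch of that scan pattern over plain std:: containers; Task, PopSharded, and the deques are illustrative stand-ins for Eigen::RunQueue, not TensorFlow APIs.

#include <cstddef>
#include <deque>
#include <functional>
#include <vector>

struct Task {
  std::function<void()> f;  // Empty when no task was found.
};

Task PopSharded(std::vector<std::deque<Task>>& shards, int start_index,
                bool search_from_all_queues) {
  Task t;
  const std::size_t n = shards.size();
  for (std::size_t j = 0; j < n; ++j) {
    // Start at the caller's shard and wrap around.
    std::deque<Task>& q = shards[(start_index + j) % n];
    if (!q.empty()) {
      t = q.back();  // The real code pops from the back of a RunQueue.
      q.pop_back();
      return t;
    }
    if (!search_from_all_queues) {
      break;  // Only the caller's own shard was requested.
    }
  }
  return t;  // t.f is empty: nothing found.
}

Different threads pass different start indices, so they drain different shards first; that is what spreads contention across the sharded queues.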
@@ -243,6 +295,12 @@ class ThreadWorkSource {
   void SetTracemeId(int64 value) { traceme_id_ = value; }
   void SetRank(int64 value) { rank_ = value; }
 
+  void SetWaiter(Waiter* waiter, mutex* mutex) {
+    mutex_lock l(run_handler_waiter_mu_);
+    sub_thread_pool_waiter_ = waiter;
+    sub_thread_pool_waiter_mu_ = mutex;
+  }
+
   int64 GetInflightTaskCount(bool is_blocking) {
     std::atomic<int64>* counter =
         is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
@@ -274,20 +332,6 @@ class ThreadWorkSource {
   }
 
  private:
-  // To reduce cache misses, we use a doubly-linked list of Waiter structs and
-  // queue them in LIFO order rather than the FIFO order used by a single
-  // condition variable.
-  struct Waiter {
-    Waiter() {
-      next = this;
-      prev = this;
-    }
-    condition_variable cv;
-    mutex mu;
-    Waiter* next;
-    Waiter* prev;
-  };
-
   struct NonBlockingQueue {
     mutex queue_op_mu;
     char pad[128];
@@ -307,6 +351,10 @@ class ThreadWorkSource {
   Waiter queue_waiters_ GUARDED_BY(waiters_mu_);
   std::atomic<int64> traceme_id_;
   std::atomic<int64> rank_;
+
+  mutex run_handler_waiter_mu_;
+  mutex* sub_thread_pool_waiter_mu_ GUARDED_BY(run_handler_waiter_mu_);
+  Waiter* sub_thread_pool_waiter_ GUARDED_BY(run_handler_waiter_mu_);
 };
 
 class RunHandlerThreadPool {
@@ -319,25 +367,33 @@ class RunHandlerThreadPool {
 
   RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                        Env* env, const ThreadOptions& thread_options,
-                       const string& name)
+                       const string& name,
+                       Eigen::MaxSizeVector<mutex>* waiters_mu,
+                       Eigen::MaxSizeVector<Waiter>* queue_waiters)
       : num_threads_(num_blocking_threads + num_non_blocking_threads),
        num_blocking_threads_(num_blocking_threads),
        num_non_blocking_threads_(num_non_blocking_threads),
        thread_data_(num_threads_),
        env_(env, thread_options, name),
-        name_(name) {
+        name_(name),
+        waiters_mu_(waiters_mu),
+        queue_waiters_(queue_waiters),
+        use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
+            "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
+        num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
+            std::vector<int>(
+                {num_blocking_threads / 2,
+                 num_blocking_threads - num_blocking_threads / 2}))),
+        sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
+            std::vector<double>({0, 0.4}))),
+        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+            std::vector<double>({0.4, 1}))) {
     VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
             << num_blocking_threads_ << " blocking threads and "
             << num_non_blocking_threads_ << " non-blocking threads.";
-    cancelled_ = false;
-
-    thread_data_.resize(num_threads_);
-    for (int i = 0; i < num_threads_; i++) {
-      thread_data_[i].thread.reset(
-          env_.CreateThread([this, i, num_blocking_threads]() {
-            WorkerLoop(i, i < num_blocking_threads);
-          }));
-    }
   }
 
   ~RunHandlerThreadPool() {
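The constructor above now pulls its sub-pool layout from the environment, with the defaults shown. The sketch below is one plausible way to pin that layout before the pool is created, mirroring the setenv calls in the tests added later in this commit; the 16-thread figure, the helper name, and the comment interpretation of the percentages are assumptions for illustration.

#include <cstdlib>

// Hypothetical helper: must run before the RunHandlerPool is constructed,
// since the values are read once in the constructor. With 16 blocking
// threads the defaults already split them 8/8; the percentages mean pool 0
// serves roughly the oldest 40% of active requests and pool 1 the rest.
void ConfigureRunHandlerSubPools() {
  setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", /*overwrite=*/1);
  setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", 1);
  setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", 1);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.4", 1);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.4,1", 1);
}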
@@ -353,6 +409,26 @@ class RunHandlerThreadPool {
     }
   }
 
+  void Start() {
+    cancelled_ = false;
+    thread_data_.resize(num_threads_);
+    int num_blocking_threads = num_blocking_threads_;
+    for (int i = 0; i < num_threads_; i++) {
+      int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
+      for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
+        if (i < num_threads_in_sub_thread_pool_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
+      thread_data_[i].thread.reset(
+          env_.CreateThread([this, i, num_blocking_threads]() {
+            WorkerLoop(i, i < num_blocking_threads);
+          }));
+    }
+  }
+
   void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                       std::function<void()> fn) {
     Task t = env_.CreateTask(std::move(fn));
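Note that Start() compares the thread index against each pool's size entry rather than a running prefix sum, with unmatched indices falling into the last pool; for the default two-pool split this still yields contiguous blocks. A small self-checking sketch of the assignment, with illustrative names:

#include <cassert>
#include <vector>

int SubThreadPoolId(int thread_index,
                    const std::vector<int>& threads_per_pool) {
  int id = static_cast<int>(threads_per_pool.size()) - 1;  // Default: last.
  for (int j = 0; j < static_cast<int>(threads_per_pool.size()); ++j) {
    if (thread_index < threads_per_pool[j]) {
      id = j;
      break;
    }
  }
  return id;
}

int main() {
  std::vector<int> sizes = {8, 8};  // The default 8/8 split of 16 threads.
  assert(SubThreadPoolId(0, sizes) == 0);
  assert(SubThreadPoolId(7, sizes) == 0);
  assert(SubThreadPoolId(8, sizes) == 1);   // Falls through to the default.
  assert(SubThreadPoolId(15, sizes) == 1);
  return 0;
}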
@@ -384,19 +460,25 @@ class RunHandlerThreadPool {
       return;
     }
     thread_data_[tid].thread_work_sources.resize(0);
+
+    if (use_sub_thread_pool_) {
+      for (int i = 0; i < thread_work_sources.size(); ++i) {
+        thread_data_[tid].thread_work_sources.emplace_back(
+            thread_work_sources[i]);
+      }
+    } else {
       thread_data_[tid].thread_work_sources.emplace_back(
           thread_work_sources[start_request_idx]);
-    // The number of shards for the queue. Threads in each shard will prioritize
-    // different thread_work_sources. Increase the number of shards could
-    // decrease the contention in the queue.
-    // For example, when num_shards == 1:
-    // thread_work_sources are ordered as start_request_idx, 0, 1, 2, 3, 4 ...
-    // for all threads.
-    // When num_shards == 2:
+      // The number of shards for the queue. Threads in each shard will
+      // prioritize different thread_work_sources. Increasing the number of
+      // shards could decrease the contention in the queue. For example, when
+      // num_shards == 1: thread_work_sources are ordered as start_request_idx,
+      // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
       // thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
       // 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
       // 4... for the other half of the threads.
-    int num_shards = ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
-    int token = tid % num_shards;
-    for (int i = 0; i < num_shards; ++i) {
-      for (int j = token; j < thread_work_sources.size(); j += num_shards) {
+      int num_shards =
+          ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
+      int token = tid % num_shards;
+      for (int i = 0; i < num_shards; ++i) {
+        for (int j = token; j < thread_work_sources.size(); j += num_shards) {
@@ -409,6 +491,7 @@ class RunHandlerThreadPool {
         }
       }
       thread_data_[tid].sources_not_empty.notify_all();
     }
+    }
 
   PerThread* GetPerThread() {
     thread_local PerThread per_thread_;
@@ -434,13 +517,26 @@ class RunHandlerThreadPool {
 
   void WorkerLoop(int thread_id, bool may_steal_blocking_work);
 
+  // Search tasks from requests in the range searching_range_start to
+  // searching_range_end. If there are no tasks in the search range and
+  // may_steal_blocking_work is true, then search from all requests.
+  Task FindTask(
+      int searching_range_start, int searching_range_end, int thread_id,
+      int sub_thread_pool_id, int max_blocking_inflight,
+      bool may_steal_blocking_work,
+      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
+      bool* task_from_blocking_queue, ThreadWorkSource** tws);
+
   void WaitForWork(bool is_blocking, int thread_id,
                    int32 max_blocking_inflight);
 
+  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);
+
  private:
   struct ThreadData {
     ThreadData()
         : version(0),
+          current_index(0),
           thread_work_sources(static_cast<int32>(
               ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                       kMaxConcurrentHandlers))) {}
@@ -448,7 +544,9 @@ class RunHandlerThreadPool {
     uint64 version;
     condition_variable sources_not_empty;
     std::unique_ptr<Thread> thread;
+    int current_index;
     Eigen::MaxSizeVector<ThreadWorkSource*> thread_work_sources GUARDED_BY(mu);
+    int sub_thread_pool_id;
   };
 
   const int num_threads_;
@@ -458,8 +556,58 @@ class RunHandlerThreadPool {
   RunHandlerEnvironment env_;
   std::atomic<bool> cancelled_;
   string name_;
+  Eigen::MaxSizeVector<mutex>* waiters_mu_;
+  Eigen::MaxSizeVector<Waiter>* queue_waiters_;
+
+  bool use_sub_thread_pool_;
+  std::vector<int> num_threads_in_sub_thread_pool_;
+
+  // Threads in each sub thread pool will search tasks from the given
+  // start_request_percentage to end_request_percentage in a round robin
+  // fashion.
+  std::vector<double> sub_thread_pool_start_request_percentage_;
+  std::vector<double> sub_thread_pool_end_request_percentage_;
 };
 
+Task RunHandlerThreadPool::FindTask(
+    int searching_range_start, int searching_range_end, int thread_id,
+    int sub_thread_pool_id, int max_blocking_inflight,
+    bool may_steal_blocking_work,
+    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
+    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
+  Task t;
+  int current_index = thread_data_[thread_id].current_index;
+  *task_from_blocking_queue = false;
+
+  // TODO(chaox): Change the search algorithm from round robin to random
+  // walk.
+  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
+    if (current_index >= searching_range_end) {
+      current_index = searching_range_start;
+    }
+    *tws = thread_work_sources[current_index];
+    ++current_index;
+
+    // For blocking thread, search for blocking tasks first.
+    if (may_steal_blocking_work &&
+        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
+      t = (*tws)->PopBlockingTask();
+      if (t.f) {
+        *task_from_blocking_queue = true;
+        break;
+      }
+    }
+
+    // Search for non-blocking tasks.
+    t = (*tws)->PopNonBlockingTask(thread_id, true);
+    if (t.f) {
+      break;
+    }
+  }
+  thread_data_[thread_id].current_index = current_index;
+  return t;
+}
+
 // Main worker thread loop.
 void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                       bool may_steal_blocking_work) {
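FindTask is driven by the half-open request range a worker derives from its sub pool's percentage entries. Below is a small sketch of that arithmetic, with a worked example for ten active requests under the default percentages; the names are illustrative, and the truncation matches the implicit double-to-int conversion at the FindTask call site shown in the next hunk.

#include <cstdio>
#include <vector>

struct Range {
  int start;
  int end;
};

Range SearchRange(int active_requests, int pool_id,
                  const std::vector<double>& start_pct,
                  const std::vector<double>& end_pct) {
  // Fractional positions truncate toward zero, as in the WorkerLoop call.
  return Range{static_cast<int>(active_requests * start_pct[pool_id]),
               static_cast<int>(active_requests * end_pct[pool_id])};
}

int main() {
  std::vector<double> start = {0, 0.4};
  std::vector<double> end = {0.4, 1};
  for (int pool = 0; pool < 2; ++pool) {
    Range r = SearchRange(10, pool, start, end);
    // Prints: pool 0 searches requests [0, 4); pool 1 searches [4, 10).
    std::printf("pool %d searches requests [%d, %d)\n", pool, r.start, r.end);
  }
  return 0;
}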
@@ -474,11 +622,47 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
     bool task_from_blocking_queue = true;
     Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
         &thread_data_[thread_id].thread_work_sources;
-    {
+    int sub_thread_pool_id;
+    if (use_sub_thread_pool_) {
+      // The mutex is not hot since it's per thread and can only be held
+      // by some other thread when a session run starts/finishes.
+      mutex_lock l(thread_data_[thread_id].mu);
+      sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
+      int active_requests = thread_work_sources->size();
+      if (may_steal_blocking_work) {
+        // Each thread will first look for tasks from requests that belong to
+        // its sub thread pool.
+        t = FindTask(
+            active_requests *
+                sub_thread_pool_start_request_percentage_[sub_thread_pool_id],
+            active_requests *
+                sub_thread_pool_end_request_percentage_[sub_thread_pool_id],
+            thread_id, sub_thread_pool_id, kMaxBlockingInflight,
+            /*may_steal_blocking_work=*/true, *thread_work_sources,
+            &task_from_blocking_queue, &tws);
+        if (!t.f) {
+          // Search from all requests if the thread cannot find tasks from
+          // requests that belong to its own sub thread pool.
+          t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
+                       kMaxBlockingInflight,
+                       /*may_steal_blocking_work=*/true, *thread_work_sources,
+                       &task_from_blocking_queue, &tws);
+        }
+      } else {
+        // For non-blocking threads, it will always search from all pending
+        // requests.
+        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
+                     kMaxBlockingInflight,
+                     /*may_steal_blocking_work=*/false, *thread_work_sources,
+                     &task_from_blocking_queue, &tws);
+      }
+    } else {
       // The mutex is not hot since it's per thread and can only be held
       // by some other thread when a session run starts/finishes.
       mutex_lock l(thread_data_[thread_id].mu);
+
+      // TODO(chaox): Refactor the following code to share the logic with
+      // FindTask.
       for (int i = 0; i < thread_work_sources->size(); ++i) {
         tws = (*thread_work_sources)[i];
         // We want a smallish numbers of inter threads since
@@ -495,20 +679,16 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
           // Always look for any work from the "primary" work source.
           // This way when we wake up a thread for a new closure we are
           // guaranteed it can be worked on.
-          for (int j = 0; j < tws->NonBlockingWorkShardingFactor(); ++j) {
-            t = tws->PopNonBlockingTask((j + thread_id) %
-                                        tws->NonBlockingWorkShardingFactor());
+          t = tws->PopNonBlockingTask(thread_id, true);
           if (t.f) {
             task_from_blocking_queue = false;
             break;
           }
-          }
           if (t.f) {
             break;
           }
         } else {
-          t = tws->PopNonBlockingTask(thread_id %
-                                      tws->NonBlockingWorkShardingFactor());
+          t = tws->PopNonBlockingTask(thread_id, false);
           if (t.f) {
             task_from_blocking_queue = false;
             break;
@@ -542,11 +722,29 @@ void RunHandlerThreadPool::WorkerLoop(int thread_id,
                   << (*thread_work_sources)[i]->ToString();
         }
       }
+      if (use_sub_thread_pool_) {
+        WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
+      } else {
         WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
+      }
     }
   }
 }
 
+void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
+                                                      int sub_thread_pool_id) {
+  const int kMaxSleepMicros = 250;
+
+  // The non-blocking thread will just sleep.
+  if (!is_blocking) {
+    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
+    return;
+  }
+
+  thread_local Waiter waiter;
+  WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
+               &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
+}
+
 void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
                                        int32 max_blocking_inflight) {
@@ -636,16 +834,33 @@ class RunHandlerPool::Impl {
   explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
       : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
+        waiters_mu_(
+            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
+        queue_waiters_(
+            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
         run_handler_thread_pool_(new RunHandlerThreadPool(
             num_inter_op_threads, num_intra_op_threads, Env::Default(),
-            ThreadOptions(), "tf_run_handler_pool")),
+            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
+            &queue_waiters_)),
         iterations_(0),
-        version_(0) {
+        version_(0),
+        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
+            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+            std::vector<double>({1}))) {
     VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
     for (int i = 0; i < max_handlers_; ++i) {
       handlers_.emplace_back(new RunHandler::Impl(this));
       free_handlers_.push_back(handlers_.back().get());
     }
+    queue_waiters_.resize(
+        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
+    waiters_mu_.resize(
+        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
+    for (auto& queue_waiter : queue_waiters_) {
+      queue_waiter.next = &queue_waiter;
+      queue_waiter.prev = &queue_waiter;
+    }
+    run_handler_thread_pool_->Start();
   }
 
   ~Impl() {
@@ -693,6 +908,19 @@ class RunHandlerPool::Impl {
     for (int i = 0; i < num_active_requests; ++i) {
       (*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
       (*thread_work_sources)[i]->SetRank(i);
+      int sub_thread_pool_id =
+          sub_thread_pool_end_request_percentage_.size() - 1;
+      for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
+           ++j) {
+        if (i < num_active_requests *
+                    sub_thread_pool_end_request_percentage_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      (*thread_work_sources)[i]->SetWaiter(
+          &queue_waiters_[sub_thread_pool_id],
+          &waiters_mu_[sub_thread_pool_id]);
     }
     version = ++version_;
   }
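The loop above routes each active request, ranked by handler start time, to the first sub pool whose end percentage covers its rank, so wake-ups for that request land on the matching pool's waiter queue. A self-checking sketch of the mapping, with illustrative names:

#include <cassert>
#include <vector>

int PoolForRequest(int i, int num_active_requests,
                   const std::vector<double>& end_pct) {
  int id = static_cast<int>(end_pct.size()) - 1;  // Default: last pool.
  for (int j = 0; j < static_cast<int>(end_pct.size()); ++j) {
    if (i < num_active_requests * end_pct[j]) {
      id = j;
      break;
    }
  }
  return id;
}

int main() {
  // With 10 active requests and end percentages {0.4, 1}, requests 0-3
  // notify pool 0's waiter queue and requests 4-9 notify pool 1's.
  std::vector<double> end_pct = {0.4, 1};
  assert(PoolForRequest(3, 10, end_pct) == 0);
  assert(PoolForRequest(4, 10, end_pct) == 1);
  return 0;
}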
@@ -738,6 +966,19 @@ class RunHandlerPool::Impl {
     for (int i = 0; i < num_active_requests; ++i) {
       (*thread_work_sources)[i] = sorted_active_handlers_[i]->tws();
       (*thread_work_sources)[i]->SetRank(i);
+      int sub_thread_pool_id =
+          sub_thread_pool_end_request_percentage_.size() - 1;
+      for (int j = 0; j < sub_thread_pool_end_request_percentage_.size();
+           ++j) {
+        if (i < num_active_requests *
+                    sub_thread_pool_end_request_percentage_[j]) {
+          sub_thread_pool_id = j;
+          break;
+        }
+      }
+      (*thread_work_sources)[i]->SetWaiter(
+          &queue_waiters_[sub_thread_pool_id],
+          &waiters_mu_[sub_thread_pool_id]);
     }
     version = ++version_;
     LogInfo();
@@ -759,6 +1000,9 @@ class RunHandlerPool::Impl {
   // inference).
   const int max_handlers_;
 
+  Eigen::MaxSizeVector<mutex> waiters_mu_;
+  Eigen::MaxSizeVector<Waiter> queue_waiters_;
+
   std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
   // Thread compatible part used only by lock under RunHandlerPool.
   // Handlers are sorted by start time.
@@ -773,6 +1017,7 @@ class RunHandlerPool::Impl {
   condition_variable one_handler_free_;
   mutex mu_;
   int64 version_ GUARDED_BY(mu_);
+  const std::vector<double> sub_thread_pool_end_request_percentage_;
 };
 
 void RunHandlerPool::Impl::RecomputePoolStats(
@@ -24,11 +24,17 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/synchronization/barrier.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 namespace {
@@ -72,5 +78,132 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
   counter.Wait();
 }
 
+SessionOptions DefaultSessionOptions() {
+  SessionOptions options;
+  (*options.config.mutable_device_count())["CPU"] = 2;
+  return options;
+}
+
+std::unique_ptr<Session> CreateSession() {
+  return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
+}
+
+class RunHandlerTest : public ::testing::Test {
+ public:
+  void Initialize(std::initializer_list<float> a_values) {
+    Graph graph(OpRegistry::Global());
+
+    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+    test::FillValues<float>(&a_tensor, a_values);
+    Node* a = test::graph::Constant(&graph, a_tensor);
+    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+    a_ = a->name();
+
+    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+    test::FillValues<float>(&x_tensor, {1, 1});
+    Node* x = test::graph::Constant(&graph, x_tensor);
+    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+    x_ = x->name();
+
+    // y = A * x
+    Node* y = test::graph::Matmul(&graph, a, x, false, false);
+    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+    y_ = y->name();
+
+    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
+    y_neg_ = y_neg->name();
+    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
+    z_ = z->name();
+    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+    graph.ToGraphDef(&def_);
+
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
+    ASSERT_EQ(
+        setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
+        0);
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
+                     "0,0.4", true),
+              0);
+    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
+                     "0.4,1", true),
+              0);
+    ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
+  }
+
+  string a_;
+  string x_;
+  string y_;
+  string y_neg_;
+  string z_;
+  GraphDef def_;
+};
+
+TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
+  Initialize({3, 2, -1, 0});
+  auto session = CreateSession();
+  ASSERT_TRUE(session != nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
+  std::vector<std::pair<string, Tensor>> inputs;
+
+  // Request two targets: one fetch output and one non-fetched output.
+  std::vector<string> output_names = {y_ + ":0"};
+  std::vector<string> target_nodes = {y_neg_};
+  std::vector<Tensor> outputs;
+
+  // Prepares RunOptions and RunMetadata
+  RunOptions run_options;
+  run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+  Status s = session->Run(run_options, inputs, output_names, target_nodes,
+                          &outputs, nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), s);
+
+  ASSERT_EQ(1, outputs.size());
+  // The first output should be initialized and have the correct
+  // output.
+  auto mat = outputs[0].matrix<float>();
+  ASSERT_TRUE(outputs[0].IsInitialized());
+  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
+TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
+  Initialize({1, 2, 3, 4});
+  auto session = CreateSession();
+  ASSERT_TRUE(session != nullptr);
+  EXPECT_EQ(::tensorflow::Status::OK(), session->Create(def_));
+
+  RunOptions run_options;
+  run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+  // Fill in the input and ask for the output
+  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+
+  // Run the graph 1000 times in 4 different threads concurrently.
+  std::vector<string> output_names = {y_ + ":0"};
+  auto fn = [&session, output_names, run_options]() {
+    for (int i = 0; i < 1000; ++i) {
+      std::vector<std::pair<string, Tensor>> inputs;
+      std::vector<Tensor> outputs;
+      // Run the graph
+      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
+                              nullptr);
+      EXPECT_EQ(::tensorflow::Status::OK(), s);
+      ASSERT_EQ(1, outputs.size());
+      auto mat = outputs[0].matrix<float>();
+      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
+    }
+  };
+
+  for (int i = 0; i < 4; ++i) {
+    tp->Schedule(fn);
+  }
+
+  // Wait for the functions to finish.
+  delete tp;
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/str_util.h"
 
 namespace tensorflow {
 
@@ -29,6 +30,54 @@ double ParamFromEnvWithDefault(const std::string& var_name,
   return (val && strings::safe_strtod(val, &num)) ? num : default_value;
 }
 
+std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
+                                            std::vector<double> default_value) {
+  const char* val = std::getenv(var_name.c_str());
+  if (!val) {
+    return default_value;
+  }
+  std::vector<string> splits = str_util::Split(val, ",");
+  std::vector<double> result;
+  result.reserve(splits.size());
+  for (auto& split : splits) {
+    double num;
+    if (strings::safe_strtod(split, &num)) {
+      result.push_back(num);
+    } else {
+      LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
+      return default_value;
+    }
+  }
+  return result;
+}
+
+std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
+                                         std::vector<int> default_value) {
+  const char* val = std::getenv(var_name.c_str());
+  if (!val) {
+    return default_value;
+  }
+  std::vector<string> splits = str_util::Split(val, ",");
+  std::vector<int> result;
+  result.reserve(splits.size());
+  for (auto& split : splits) {
+    int num;
+    if (strings::safe_strto32(split, &num)) {
+      result.push_back(num);
+    } else {
+      LOG(ERROR) << "Wrong format for " << var_name << ". Use default value.";
+      return default_value;
+    }
+  }
+  return result;
+}
+
+bool ParamFromEnvBoolWithDefault(const std::string& var_name,
+                                 bool default_value) {
+  const char* val = std::getenv(var_name.c_str());
+  return (val) ? str_util::Lowercase(val) == "true" : default_value;
+}
+
 void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
                                     int min_threads_per_request,
                                     std::vector<std::uint_fast32_t>* start_vec,
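A brief usage sketch of the parsers defined above, assuming those declarations are in scope; the env-var names here are made up for illustration, and the behavior mirrors the unit test added at the end of this commit: comma-separated values parse element-wise, and any malformed element falls back to the entire default.

#include <cstdlib>
#include <vector>

void Demo() {
  // "MY_POOL_SIZES" and "MY_FLAG" are hypothetical variable names.
  setenv("MY_POOL_SIZES", "4,12", /*overwrite=*/1);
  std::vector<int> sizes = ParamFromEnvWithDefault(
      "MY_POOL_SIZES", std::vector<int>({8, 8}));  // Yields {4, 12}.

  setenv("MY_POOL_SIZES", "4,oops", 1);
  sizes = ParamFromEnvWithDefault(
      "MY_POOL_SIZES", std::vector<int>({8, 8}));  // Falls back to {8, 8}.

  // Unset booleans return the default; only the string "true" enables them.
  bool flag = ParamFromEnvBoolWithDefault("MY_FLAG", false);  // false.
  (void)flag;
}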
@@ -54,10 +54,27 @@ void ComputeInterOpStealingRanges(int num_threads, int min_threads_per_domain,
 std::vector<int> ChooseRequestsWithExponentialDistribution(
     int num_active_requests, int num_threads);
 
-// Loop environment variable named 'var_name' and return the value if it exist
-// and can be parsed. Return 'default_value' otherwise.
+// Look up environment variable named 'var_name' and return the value if it
+// exists and can be parsed. Return 'default_value' otherwise.
 double ParamFromEnvWithDefault(const std::string& var_name,
                                double default_value);
 
+// Look up environment variable named 'var_name' and return the value if it
+// exists and can be parsed. The value must be in format val1,val2... Return
+// 'default_value' otherwise.
+std::vector<double> ParamFromEnvWithDefault(const std::string& var_name,
+                                            std::vector<double> default_value);
+
+// Look up environment variable named 'var_name' and return the value if it
+// exists and can be parsed. The value must be in format val1,val2... Return
+// 'default_value' otherwise.
+std::vector<int> ParamFromEnvWithDefault(const std::string& var_name,
+                                         std::vector<int> default_value);
+
+// Look up environment variable named 'var_name' and return the value if it
+// exists and can be parsed. Return 'default_value' otherwise.
+bool ParamFromEnvBoolWithDefault(const std::string& var_name,
+                                 bool default_value);
+
 }  // end namespace tensorflow
 #endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
@@ -124,5 +124,44 @@ TEST(RunHandlerUtilTest, TestExponentialRequestDistribution) {
   ASSERT_EQ(actual_distribution, expected_distribution);
 }
 
+TEST(RunHandlerUtilTest, TestParamFromEnvWithDefault) {
+  std::vector<double> result = ParamFromEnvWithDefault(
+      "RUN_HANDLER_TEST_ENV", std::vector<double>{0, 0, 0});
+  EXPECT_EQ(result.size(), 3);
+  EXPECT_EQ(result[0], 0);
+  EXPECT_EQ(result[1], 0);
+  EXPECT_EQ(result[2], 0);
+
+  std::vector<int> result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
+                                                     std::vector<int>{0, 0, 0});
+  EXPECT_EQ(result2.size(), 3);
+  EXPECT_EQ(result2[0], 0);
+  EXPECT_EQ(result2[1], 0);
+  EXPECT_EQ(result2[2], 0);
+
+  bool result3 =
+      ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
+  EXPECT_EQ(result3, false);
+
+  // Set environment variable.
+  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV", "1,2,3", true), 0);
+  result = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
+                                   std::vector<double>{0, 0, 0});
+  EXPECT_EQ(result.size(), 3);
+  EXPECT_EQ(result[0], 1);
+  EXPECT_EQ(result[1], 2);
+  EXPECT_EQ(result[2], 3);
+  result2 = ParamFromEnvWithDefault("RUN_HANDLER_TEST_ENV",
+                                    std::vector<int>{0, 0, 0});
+  EXPECT_EQ(result2.size(), 3);
+  EXPECT_EQ(result2[0], 1);
+  EXPECT_EQ(result2[1], 2);
+  EXPECT_EQ(result2[2], 3);
+
+  EXPECT_EQ(setenv("RUN_HANDLER_TEST_ENV_BOOL", "true", true), 0);
+  result3 = ParamFromEnvBoolWithDefault("RUN_HANDLER_TEST_ENV_BOOL", false);
+  EXPECT_EQ(result3, true);
+}
+
 }  // namespace
 }  // namespace tensorflow