Switch RunHandler to use an inference friendly thread pool.

Each inference has a dedicated work queue. All threads steal the work
in the priority order of the request (currently arrival time). Note that
there is one pool for both intra and inter work. However to avoid there are
some thread the are not allowed to steal inter work, which can be blocking.

PiperOrigin-RevId: 254257458
This commit is contained in:
A. Unique TensorFlower 2019-06-20 13:14:41 -07:00 committed by TensorFlower Gardener
parent 7ed84ad814
commit b71a4c0765
5 changed files with 528 additions and 96 deletions

View File

@ -4743,6 +4743,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "framework_run_handler_test",
size = "small",
srcs = ["framework/run_handler_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":framework_internal",
":lib",
":lib_internal",
":test",
":test_main",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "common_runtime_partitioning_utils_test",
size = "small",

View File

@ -66,6 +66,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
@ -251,8 +252,38 @@ std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
static RunHandlerPool* GetOrCreateRunHandlerPool(
const SessionOptions& options) {
int num_inter_threads = 0;
int num_intra_threads = 0;
static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
if (env_num_inter_threads > 0) {
num_inter_threads = env_num_inter_threads;
}
if (env_num_intra_threads > 0) {
num_intra_threads = env_num_intra_threads;
}
if (num_inter_threads == 0) {
if (options.config.session_inter_op_thread_pool_size() > 0) {
// Note due to ShouldUseRunHandler we are guaranteed that
// run_options.inter_op_thread_pool() == 0
num_inter_threads =
options.config.session_inter_op_thread_pool(0).num_threads();
}
if (num_inter_threads == 0) {
num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
}
}
if (num_intra_threads == 0) {
num_intra_threads = options.config.intra_op_parallelism_threads();
if (num_intra_threads == 0) {
num_intra_threads = port::NumSchedulableCPUs();
}
}
static RunHandlerPool* pool =
new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
new RunHandlerPool(num_inter_threads, num_intra_threads);
return pool;
}
@ -630,7 +661,7 @@ Status DirectSession::RunInternal(
if (ShouldUseRunHandlerPool(run_options) &&
run_options.experimental().use_run_handler_pool()) {
VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
handler = GetOrCreateRunHandlerPool(options_)->Get();
handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
}
auto* handler_ptr = handler.get();
@ -663,6 +694,10 @@ Status DirectSession::RunInternal(
device_thread_pool->Schedule(std::move(c));
};
}
if (handler != nullptr) {
args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface();
}
item.executor->RunAsync(args, barrier->Get());
}

View File

@ -19,62 +19,331 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
static constexpr int32 kMaxConcurrentHandlers = 128;
// TODO(azaks): Refactor with thread:ThreadPool
class RunHandlerEnvironment {
typedef Thread EnvThread;
struct TaskImpl {
std::function<void()> f;
Context context;
uint64 trace_id;
};
Env* const env_;
const ThreadOptions thread_options_;
const string name_;
public:
struct Task {
std::unique_ptr<TaskImpl> f;
};
RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
const string& name)
: env_(env), thread_options_(thread_options), name_(name) {}
EnvThread* CreateThread(std::function<void()> f) {
return env_->StartThread(thread_options_, name_, [=]() {
// Set the processor flag to flush denormals to zero.
port::ScopedFlushDenormal flush;
// Set the processor rounding mode to ROUND TO NEAREST.
port::ScopedSetRound round(FE_TONEAREST);
if (thread_options_.numa_node != port::kNUMANoAffinity) {
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
}
f();
});
}
Task CreateTask(std::function<void()> f) {
uint64 id = 0;
if (tracing::EventCollector::IsEnabled()) {
id = tracing::GetUniqueArg();
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
}
return Task{
std::unique_ptr<TaskImpl>(new TaskImpl{
std::move(f),
Context(ContextKind::kThread),
id,
}),
};
}
void ExecuteTask(const Task& t) {
WithContext wc(t.f->context);
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
t.f->trace_id);
t.f->f();
}
};
class RunHandlerThreadPool {
public:
typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;
struct PerThread {
constexpr PerThread() : pool(nullptr), thread_id(-1) {}
RunHandlerThreadPool* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
};
RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
Env* env, const ThreadOptions& thread_options,
const string& name)
: num_threads_(num_blocking_threads + num_non_blocking_threads),
num_blocking_threads_(num_blocking_threads),
num_non_blocking_threads_(num_non_blocking_threads),
thread_data_(num_threads_),
env_(env, thread_options, name),
name_(name) {
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
<< num_blocking_threads_ << " blocking threads and "
<< num_non_blocking_threads_ << " non-blocking threads.";
cancelled_ = false;
thread_data_.resize(num_threads_);
for (int i = 0; i < num_threads_; i++) {
thread_data_[i].thread.reset(
env_.CreateThread([this, i, num_blocking_threads]() {
WorkerLoop(i, i < num_blocking_threads);
}));
}
}
~RunHandlerThreadPool() {
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
cancelled_ = true;
for (size_t i = 0; i < thread_data_.size(); ++i) {
thread_data_[i].thread.reset();
}
}
struct ThreadWorkSource {
ThreadWorkSource()
: blocking_inflight(0), non_blocking_inflight(0), traceme_id(0) {}
Queue blocking_work_queue;
std::atomic<int64> blocking_inflight;
mutex blocking_mu;
Queue non_blocking_work_queue;
std::atomic<int64> non_blocking_inflight;
mutex non_blocking_mu;
std::atomic<int64> traceme_id;
};
void AddWorkToQueue(Queue* q, mutex* mu, bool inter_work,
std::atomic<int64>* traceme_id,
std::function<void()> fn) {
Task t = env_.CreateTask(std::move(fn));
{
mutex_lock l(*mu);
// For a given queue, only one thread can call PushFront.
t = q->PushFront(std::move(t));
VLOG(3) << "Added " << (inter_work ? "inter" : "intra") << " work from "
<< traceme_id->load(std::memory_order_relaxed);
}
if (t.f) {
VLOG(3) << "Running " << (inter_work ? "inter" : "intra") << " work from "
<< traceme_id->load(std::memory_order_relaxed);
env_.ExecuteTask(t);
}
}
// Set work queues from which the thread 'tid' can steal its work.
void SetThreadWorkSources(
int tid,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
mutex_lock l(thread_data_[tid].mu);
thread_data_[tid].thread_work_sources.resize(0);
for (int i = 0; i < thread_work_sources.size(); ++i) {
thread_data_[tid].thread_work_sources.emplace_back(
thread_work_sources[i]);
}
}
PerThread* GetPerThread() {
thread_local PerThread per_thread_;
PerThread* pt = &per_thread_;
return pt;
}
int CurrentThreadId() const {
const PerThread* pt =
const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
return -1;
}
}
int NumThreads() const { return num_threads_; }
int NumBlockingThreads() const { return num_blocking_threads_; }
int NumNonBlockingThreads() const { return num_non_blocking_threads_; }
void WorkerLoop(int thread_id, bool may_steal_blocking_work);
private:
struct ThreadData {
ThreadData() : thread_work_sources(kMaxConcurrentHandlers) {}
mutex mu;
std::unique_ptr<Thread> thread;
Eigen::MaxSizeVector<ThreadWorkSource*> thread_work_sources GUARDED_BY(mu);
};
const int num_threads_;
const int num_blocking_threads_;
const int num_non_blocking_threads_;
Eigen::MaxSizeVector<ThreadData> thread_data_;
RunHandlerEnvironment env_;
std::atomic<bool> cancelled_;
string name_;
};
// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
bool may_steal_blocking_work) {
PerThread* pt = GetPerThread();
pt->pool = this;
pt->thread_id = thread_id;
while (!cancelled_) {
Task t;
bool inter_work = true;
std::atomic<int64>* inflight_counter = nullptr;
int64 traceme_id = 0;
Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
&thread_data_[thread_id].thread_work_sources;
{
// The mutex is not hot since its per thread and can only be held
// by some other thread when a session run starts/finishes.
mutex_lock l(thread_data_[thread_id].mu);
for (int i = 0; i < thread_work_sources->size(); ++i) {
ThreadWorkSource* tws = (*thread_work_sources)[i];
// We want a smallish numbers of inter threads since
// otherwise there will be contention in PropagateOutputs.
// This is best effort policy.
static constexpr int32 kMaxBlockingInflight = 10;
if (may_steal_blocking_work &&
(tws->blocking_inflight.load(std::memory_order_relaxed) <
kMaxBlockingInflight)) {
t = tws->blocking_work_queue.PopBack();
if (t.f) {
inflight_counter = &(tws->blocking_inflight);
traceme_id = tws->traceme_id.load(std::memory_order_relaxed);
break;
}
}
t = tws->non_blocking_work_queue.PopBack();
if (t.f) {
inflight_counter = &(tws->non_blocking_inflight);
traceme_id = tws->traceme_id.load(std::memory_order_relaxed);
inter_work = false;
break;
}
}
}
if (t.f) {
profiler::TraceMe activity(
[=] {
return strings::StrCat(inter_work ? "inter" : "intra", " ",
"#id = ", traceme_id, " ", thread_id, "#");
},
profiler::TraceMeLevel::kInfo);
VLOG(2) << "Running " << (inter_work ? "inter" : "intra") << " work from "
<< traceme_id;
inflight_counter->fetch_add(1, std::memory_order_relaxed);
env_.ExecuteTask(t);
inflight_counter->fetch_sub(1, std::memory_order_relaxed);
} else {
profiler::TraceMe activity(
[=] {
return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
},
profiler::TraceMeLevel::kInfo);
if (VLOG_IS_ON(4)) {
mutex_lock l(thread_data_[thread_id].mu);
for (int i = 0; i < thread_work_sources->size(); ++i) {
ThreadWorkSource* tws = (*thread_work_sources)[i];
VLOG(4) << "source id " << i << " traceme_id = "
<< tws->traceme_id.load(std::memory_order_relaxed)
<< " inter queue size " << tws->blocking_work_queue.Size()
<< " inter inflight "
<< tws->blocking_inflight.load(std::memory_order_relaxed)
<< " intra queue size " << tws->non_blocking_work_queue.Size()
<< " intra inflight "
<< tws->non_blocking_inflight.load(std::memory_order_relaxed);
}
}
Env::Default()->SleepForMicroseconds(250);
}
}
}
} // namespace
// Contains the concrete implementation of the RunHandler.
// Externally visible RunHandler class simply forwards the work to this one.
class RunHandler::Impl {
public:
explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
Reset();
}
explicit Impl(RunHandlerPool::Impl* pool_impl);
~Impl() {}
void set_inter_op_scheduling_range(std::uint_fast32_t start,
std::uint_fast32_t limit) {
inter_op_scheduling_range_.store(EncodePartition(start, limit),
std::memory_order_release);
}
std::uint_fast32_t inter_op_scheduling_range() const {
return inter_op_scheduling_range_.load(std::memory_order_acquire);
thread::ThreadPoolInterface* thread_pool_interface() {
return thread_pool_interface_.get();
}
// Stores now time (in microseconds) since unix epoch when the handler is
// requested via RunHandlerPool::Get().
uint64 start_time_us() const { return start_time_us_; }
int64 step_id() const { return step_id_; }
void ScheduleInterOpClosure(std::function<void()> fn);
void ScheduleIntraOpClosure(std::function<void()> fn);
void Reset();
void Reset(int64 step_id);
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
RunHandlerThreadPool::ThreadWorkSource* tws() { return &tws_; }
private:
// Encoding/decoding logic for storing [start, limit) into a single
// uint_fast32_t int. We assume that pool_num_threads < (1 << 16).
const int kMaxPartitionBits = 16;
const int kMaxThreads = 1 << kMaxPartitionBits;
class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
public:
explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
: run_handler_impl_(run_handler_impl) {}
~ThreadPoolInterfaceWrapper() override {}
void Schedule(std::function<void()> fn) override;
int NumThreads() const override;
int CurrentThreadId() const override;
std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
std::uint_fast32_t limit) {
return (start << kMaxPartitionBits) | limit;
}
private:
RunHandler::Impl* run_handler_impl_ = nullptr;
};
void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
std::uint_fast32_t* limit) {
*limit = val & (kMaxThreads - 1);
val >>= kMaxPartitionBits;
*start = val;
}
std::atomic_uint_fast32_t inter_op_scheduling_range_;
RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
uint64 start_time_us_;
int64 step_id_;
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
RunHandlerThreadPool::ThreadWorkSource tws_;
};
// Contains shared state across all run handlers present in the pool. Also
@ -82,30 +351,17 @@ class RunHandler::Impl {
// This class is thread safe.
class RunHandlerPool::Impl {
public:
explicit Impl(int num_inter_op_threads)
: max_handlers_(128),
inter_op_thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
: max_handlers_(kMaxConcurrentHandlers),
run_handler_thread_pool_(new RunHandlerThreadPool(
num_inter_op_threads, num_intra_op_threads, Env::Default(),
ThreadOptions(), "tf_run_handler_pool")),
iterations_(0) {
VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
for (int i = 0; i < max_handlers_; ++i) {
handlers_.emplace_back(new RunHandler::Impl(this));
free_handlers_.push_back(handlers_.back().get());
}
std::vector<std::pair<unsigned, unsigned>> steal_partitions(
num_inter_op_threads);
std::vector<std::uint_fast32_t> start_vec(num_inter_op_threads);
std::vector<std::uint_fast32_t> end_vec(num_inter_op_threads);
ComputeInterOpStealingRanges(num_inter_op_threads, kMinThreadsPerDomain,
&start_vec, &end_vec);
for (int i = 0; i < num_inter_op_threads; ++i) {
steal_partitions[i] = std::make_pair(start_vec[i], end_vec[i]);
VLOG(1) << "Steal partition i: " << i << " steal_start: " << start_vec[i]
<< " steal_end: " << end_vec[i];
}
inter_op_thread_pool_->SetStealPartitions(steal_partitions);
}
~Impl() {
@ -116,11 +372,11 @@ class RunHandlerPool::Impl {
DCHECK_EQ(sorted_active_handlers_.size(), 0);
}
thread::ThreadPool* inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
RunHandlerThreadPool* run_handler_thread_pool() {
return run_handler_thread_pool_.get();
}
std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
std::unique_ptr<RunHandler> Get(int64 step_id) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
while (free_handlers_.empty()) {
one_handler_free_.wait(l);
@ -128,7 +384,7 @@ class RunHandlerPool::Impl {
// Remove the last entry from free_handlers_ and add to the end of
// sorted_active_handlers_.
auto* handler_impl = free_handlers_.back();
handler_impl->Reset();
handler_impl->Reset(step_id);
// Sortedness isn't violated if we simply add at the end of the list, since
// handlers are expected to be obtained in increasing order of time.
sorted_active_handlers_.push_back(handler_impl);
@ -144,6 +400,9 @@ class RunHandlerPool::Impl {
mutex_lock l(mu_);
DCHECK_GT(sorted_active_handlers_.size(), 0);
CHECK_EQ(handler->tws()->blocking_work_queue.Size(), 0);
CHECK_EQ(handler->tws()->non_blocking_work_queue.Size(), 0);
uint64 now = tensorflow::Env::Default()->NowMicros();
double elapsed = (now - handler->start_time_us()) / 1000.0;
time_hist_.Add(elapsed);
@ -176,26 +435,16 @@ class RunHandlerPool::Impl {
// inference).
const int max_handlers_;
// Minimum number of threads allocated to process a request.
const int kMinThreadsPerRequest = 3;
// Minmum number of threads in a steal domain. Each thread will first try
// to steal from threads in the same domain before stealing from threads
// in different domains.
const int kMinThreadsPerDomain = 2 * kMinThreadsPerRequest;
// Thread safe part.
const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
// Thread compatible part used only by lock under RunHandlerPool.
// Handlers are sorted by start time.
// TODO(azaks): sort by the remaining latency budget.
std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
// Histogram of elapsed runtime of every handler (in ms).
histogram::Histogram time_hist_ GUARDED_BY(mu_);
std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
int64 iterations_ GUARDED_BY(mu_);
condition_variable one_handler_free_;
mutex mu_;
@ -204,63 +453,98 @@ class RunHandlerPool::Impl {
void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
int num_active_requests = sorted_active_handlers_.size();
if (num_active_requests == 0) return;
Eigen::MaxSizeVector<RunHandlerThreadPool::ThreadWorkSource*>
thread_work_sources(num_active_requests);
int num_threads = inter_op_thread_pool_->NumThreads();
inter_op_start_.resize(num_active_requests);
inter_op_limit_.resize(num_active_requests);
ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
kMinThreadsPerRequest, &inter_op_start_,
&inter_op_limit_);
thread_work_sources.resize(num_active_requests);
for (int i = 0; i < num_active_requests; ++i) {
sorted_active_handlers_[i]->set_inter_op_scheduling_range(
inter_op_start_[i], inter_op_limit_[i]);
thread_work_sources[i] = sorted_active_handlers_[i]->tws();
}
for (int i = 0; i < run_handler_thread_pool()->NumThreads(); ++i) {
VLOG(2) << "Setting work for tid = " << i;
run_handler_thread_pool()->SetThreadWorkSources(i, thread_work_sources);
}
if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
VLOG(1) << "Active session runs: " << num_active_requests;
uint64 now = tensorflow::Env::Default()->NowMicros();
string ranges_str = "";
string times_str = "";
string ids_str = "";
for (int i = 0; i < num_active_requests; ++i) {
if (i > 0) {
times_str += " ";
ranges_str += " ";
ids_str += " ";
}
times_str += strings::StrCat(
(now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
inter_op_limit_[i], ")");
ids_str +=
strings::StrCat(sorted_active_handlers_[i]->tws()->traceme_id.load(
std::memory_order_relaxed));
}
VLOG(1) << "Elapsed times are: " << times_str;
VLOG(1) << "Ranges are: " << ranges_str;
VLOG(1) << "Step ids are: " << ids_str;
}
}
void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
std::uint_fast32_t start = 0, limit = 0;
DecodePartition(inter_op_scheduling_range(), &start, &limit);
DCHECK_LT(start, limit);
pool_impl_->inter_op_thread_pool()->ScheduleWithHint(std::move(fn), start,
limit);
// It is important to return a value such as:
// CurrentThreadId() in [0, NumThreads)
int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
}
void RunHandler::Impl::Reset() {
set_inter_op_scheduling_range(
0, pool_impl_->inter_op_thread_pool()->NumThreads());
int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
return run_handler_impl_->pool_impl_->run_handler_thread_pool()
->CurrentThreadId();
}
void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
std::function<void()> fn) {
return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
}
RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
: pool_impl_(pool_impl) {
thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
Reset(0);
}
void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
VLOG(3) << "Scheduling inter work for "
<< tws()->traceme_id.load(std::memory_order_relaxed);
pool_impl_->run_handler_thread_pool()->AddWorkToQueue(
&tws()->blocking_work_queue, &tws()->blocking_mu, true,
&tws()->traceme_id, std::move(fn));
}
void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
VLOG(3) << "Scheduling inter work for "
<< tws()->traceme_id.load(std::memory_order_relaxed);
pool_impl_->run_handler_thread_pool()->AddWorkToQueue(
&tws()->non_blocking_work_queue, &tws()->non_blocking_mu, false,
&tws()->traceme_id, std::move(fn));
}
void RunHandler::Impl::Reset(int64 step_id) {
start_time_us_ = tensorflow::Env::Default()->NowMicros();
step_id_ = step_id;
tws_.traceme_id = step_id;
}
RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
: impl_(new Impl(num_inter_op_threads)) {}
: impl_(new Impl(num_inter_op_threads, 0)) {}
RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
int num_intra_op_threads)
: impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}
RunHandlerPool::~RunHandlerPool() {}
std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
std::unique_ptr<RunHandler> RunHandlerPool::Get(int64 step_id) {
return impl_->Get(step_id);
}
RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
@ -268,5 +552,10 @@ void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
impl_->ScheduleInterOpClosure(std::move(fn));
}
thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
return impl_->thread_pool_interface();
}
RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
} // namespace tensorflow

View File

@ -22,6 +22,10 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace Eigen {
struct ThreadPoolDevice;
}
namespace tensorflow {
class RunHandler;
@ -46,6 +50,8 @@ class RunHandler;
class RunHandlerPool {
public:
explicit RunHandlerPool(int num_inter_op_threads);
RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
~RunHandlerPool();
// Returns an inactive RunHandler from the pool.
@ -56,7 +62,7 @@ class RunHandlerPool {
// unique_ptr is destroyed.
//
// Will block unless there is an inactive handler.
std::unique_ptr<RunHandler> Get();
std::unique_ptr<RunHandler> Get(int64 step_id = 0);
private:
class Impl;
@ -65,8 +71,10 @@ class RunHandlerPool {
std::unique_ptr<Impl> impl_;
};
// RunHandler can be used to schedule inter-op closures to run on a global pool
// shared across all Session::Run(s).
// RunHandler can be used to schedule inter/intra-op closures to run on a global
// pool shared across all Session::Run(s). The closures are enqueued to a
// handler specific queue, from which the work is stolen in a priority order
// (time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
@ -78,6 +86,7 @@ class RunHandlerPool {
class RunHandler {
public:
void ScheduleInterOpClosure(std::function<void()> fn);
thread::ThreadPoolInterface* AsIntraThreadPoolInterface();
~RunHandler();

View File

@ -0,0 +1,82 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/run_handler.h"
#include <memory>
#include <vector>
#define EIGEN_USE_THREADS
#include "absl/memory/memory.h"
#include "absl/synchronization/barrier.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(RunHandlerUtilTest, TestBasicScheduling) {
int num_threads = 2;
int num_handlers = 10;
std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(num_threads));
// RunHandler has 2 * num_threads (inter + intra) -
// all should be able to run concurrently.
absl::Barrier barrier1(num_threads);
absl::Barrier barrier2(num_threads);
BlockingCounter counter(2 * num_handlers * num_threads);
int num_test_threads = 10;
thread::ThreadPool test_pool(Env::Default(), "test", num_test_threads);
for (int i = 0; i < 10; ++i) {
test_pool.Schedule([&counter, &barrier1, &barrier2, &pool, i,
num_threads]() {
auto handler = pool->Get();
BlockingCounter local_counter(2 * num_threads);
auto intra_thread_pool = handler->AsIntraThreadPoolInterface();
for (int j = 0; j < num_threads; ++j) {
handler->ScheduleInterOpClosure(
[&local_counter, &counter, &barrier1, i]() {
if (i == 2) {
barrier1.Block();
}
counter.DecrementCount();
local_counter.DecrementCount();
});
intra_thread_pool->Schedule([&local_counter, &counter, &barrier2, i]() {
if (i == 9) {
barrier2.Block();
}
counter.DecrementCount();
local_counter.DecrementCount();
});
}
local_counter.Wait();
});
}
counter.Wait();
}
} // namespace
} // namespace tensorflow