diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0e2ba3f5d4b..3450b9f0f04 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -4743,6 +4743,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "framework_run_handler_test", + size = "small", + srcs = ["framework/run_handler_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":framework_internal", + ":lib", + ":lib_internal", + ":test", + ":test_main", + "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + tf_cc_test( name = "common_runtime_partitioning_utils_test", size = "small", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index f8e54bed1b1..ba4919e737c 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -66,6 +66,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/tracing.h" @@ -251,8 +252,38 @@ std::atomic_int_fast64_t DirectSession::step_id_counter_(1); static RunHandlerPool* GetOrCreateRunHandlerPool( const SessionOptions& options) { + int num_inter_threads = 0; + int num_intra_threads = 0; + static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment(); + static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment(); + if (env_num_inter_threads > 0) { + num_inter_threads = env_num_inter_threads; + } + if (env_num_intra_threads > 0) { + num_intra_threads = env_num_intra_threads; + } + + if (num_inter_threads == 0) { + if (options.config.session_inter_op_thread_pool_size() > 0) { + // Note due to ShouldUseRunHandler we are guaranteed that + // run_options.inter_op_thread_pool() == 0 + num_inter_threads = + options.config.session_inter_op_thread_pool(0).num_threads(); + } + if (num_inter_threads == 0) { + num_inter_threads = NumInterOpThreadsFromSessionOptions(options); + } + } + + if (num_intra_threads == 0) { + num_intra_threads = options.config.intra_op_parallelism_threads(); + if (num_intra_threads == 0) { + num_intra_threads = port::NumSchedulableCPUs(); + } + } + static RunHandlerPool* pool = - new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options)); + new RunHandlerPool(num_inter_threads, num_intra_threads); return pool; } @@ -630,7 +661,7 @@ Status DirectSession::RunInternal( if (ShouldUseRunHandlerPool(run_options) && run_options.experimental().use_run_handler_pool()) { VLOG(1) << "Using RunHandler to scheduler inter-op closures."; - handler = GetOrCreateRunHandlerPool(options_)->Get(); + handler = GetOrCreateRunHandlerPool(options_)->Get(step_id); } auto* handler_ptr = handler.get(); @@ -663,6 +694,10 @@ Status DirectSession::RunInternal( device_thread_pool->Schedule(std::move(c)); }; } + if (handler != nullptr) { + args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface(); + } + item.executor->RunAsync(args, barrier->Get()); } diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc index fa7d46cca56..14a712cc653 100644 --- a/tensorflow/core/framework/run_handler.cc +++ b/tensorflow/core/framework/run_handler.cc @@ -19,62 +19,331 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/run_handler_util.h" +#include "tensorflow/core/lib/core/threadpool_interface.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/platform/setround.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { +namespace { +static constexpr int32 kMaxConcurrentHandlers = 128; + +// TODO(azaks): Refactor with thread:ThreadPool +class RunHandlerEnvironment { + typedef Thread EnvThread; + struct TaskImpl { + std::function f; + Context context; + uint64 trace_id; + }; + Env* const env_; + const ThreadOptions thread_options_; + const string name_; + + public: + struct Task { + std::unique_ptr f; + }; + + RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options, + const string& name) + : env_(env), thread_options_(thread_options), name_(name) {} + + EnvThread* CreateThread(std::function f) { + return env_->StartThread(thread_options_, name_, [=]() { + // Set the processor flag to flush denormals to zero. + port::ScopedFlushDenormal flush; + // Set the processor rounding mode to ROUND TO NEAREST. + port::ScopedSetRound round(FE_TONEAREST); + if (thread_options_.numa_node != port::kNUMANoAffinity) { + port::NUMASetThreadNodeAffinity(thread_options_.numa_node); + } + f(); + }); + } + + Task CreateTask(std::function f) { + uint64 id = 0; + if (tracing::EventCollector::IsEnabled()) { + id = tracing::GetUniqueArg(); + tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); + } + return Task{ + std::unique_ptr(new TaskImpl{ + std::move(f), + Context(ContextKind::kThread), + id, + }), + }; + } + + void ExecuteTask(const Task& t) { + WithContext wc(t.f->context); + tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, + t.f->trace_id); + t.f->f(); + } +}; + +class RunHandlerThreadPool { + public: + typedef typename RunHandlerEnvironment::Task Task; + typedef Eigen::RunQueue Queue; + + struct PerThread { + constexpr PerThread() : pool(nullptr), thread_id(-1) {} + RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. + int thread_id; // Worker thread index in pool. 
+ }; + + RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads, + Env* env, const ThreadOptions& thread_options, + const string& name) + : num_threads_(num_blocking_threads + num_non_blocking_threads), + num_blocking_threads_(num_blocking_threads), + num_non_blocking_threads_(num_non_blocking_threads), + thread_data_(num_threads_), + env_(env, thread_options, name), + name_(name) { + VLOG(1) << "Creating RunHandlerThreadPool " << name << " with " + << num_blocking_threads_ << " blocking threads and " + << num_non_blocking_threads_ << " non-blocking threads."; + cancelled_ = false; + + thread_data_.resize(num_threads_); + for (int i = 0; i < num_threads_; i++) { + thread_data_[i].thread.reset( + env_.CreateThread([this, i, num_blocking_threads]() { + WorkerLoop(i, i < num_blocking_threads); + })); + } + } + + ~RunHandlerThreadPool() { + VLOG(1) << "Exiting RunHandlerThreadPool " << name_; + + cancelled_ = true; + for (size_t i = 0; i < thread_data_.size(); ++i) { + thread_data_[i].thread.reset(); + } + } + + struct ThreadWorkSource { + ThreadWorkSource() + : blocking_inflight(0), non_blocking_inflight(0), traceme_id(0) {} + Queue blocking_work_queue; + std::atomic blocking_inflight; + mutex blocking_mu; + Queue non_blocking_work_queue; + std::atomic non_blocking_inflight; + mutex non_blocking_mu; + std::atomic traceme_id; + }; + + void AddWorkToQueue(Queue* q, mutex* mu, bool inter_work, + std::atomic* traceme_id, + std::function fn) { + Task t = env_.CreateTask(std::move(fn)); + { + mutex_lock l(*mu); + // For a given queue, only one thread can call PushFront. + t = q->PushFront(std::move(t)); + VLOG(3) << "Added " << (inter_work ? "inter" : "intra") << " work from " + << traceme_id->load(std::memory_order_relaxed); + } + if (t.f) { + VLOG(3) << "Running " << (inter_work ? "inter" : "intra") << " work from " + << traceme_id->load(std::memory_order_relaxed); + env_.ExecuteTask(t); + } + } + + // Set work queues from which the thread 'tid' can steal its work. + void SetThreadWorkSources( + int tid, + const Eigen::MaxSizeVector& thread_work_sources) { + mutex_lock l(thread_data_[tid].mu); + thread_data_[tid].thread_work_sources.resize(0); + for (int i = 0; i < thread_work_sources.size(); ++i) { + thread_data_[tid].thread_work_sources.emplace_back( + thread_work_sources[i]); + } + } + + PerThread* GetPerThread() { + thread_local PerThread per_thread_; + PerThread* pt = &per_thread_; + return pt; + } + + int CurrentThreadId() const { + const PerThread* pt = + const_cast(this)->GetPerThread(); + if (pt->pool == this) { + return pt->thread_id; + } else { + return -1; + } + } + + int NumThreads() const { return num_threads_; } + + int NumBlockingThreads() const { return num_blocking_threads_; } + + int NumNonBlockingThreads() const { return num_non_blocking_threads_; } + + void WorkerLoop(int thread_id, bool may_steal_blocking_work); + + private: + struct ThreadData { + ThreadData() : thread_work_sources(kMaxConcurrentHandlers) {} + mutex mu; + std::unique_ptr thread; + Eigen::MaxSizeVector thread_work_sources GUARDED_BY(mu); + }; + + const int num_threads_; + const int num_blocking_threads_; + const int num_non_blocking_threads_; + Eigen::MaxSizeVector thread_data_; + RunHandlerEnvironment env_; + std::atomic cancelled_; + string name_; +}; + +// Main worker thread loop. 
+void RunHandlerThreadPool::WorkerLoop(int thread_id, + bool may_steal_blocking_work) { + PerThread* pt = GetPerThread(); + pt->pool = this; + pt->thread_id = thread_id; + + while (!cancelled_) { + Task t; + bool inter_work = true; + std::atomic* inflight_counter = nullptr; + int64 traceme_id = 0; + Eigen::MaxSizeVector* thread_work_sources = + &thread_data_[thread_id].thread_work_sources; + { + // The mutex is not hot since its per thread and can only be held + // by some other thread when a session run starts/finishes. + mutex_lock l(thread_data_[thread_id].mu); + + for (int i = 0; i < thread_work_sources->size(); ++i) { + ThreadWorkSource* tws = (*thread_work_sources)[i]; + // We want a smallish numbers of inter threads since + // otherwise there will be contention in PropagateOutputs. + // This is best effort policy. + static constexpr int32 kMaxBlockingInflight = 10; + if (may_steal_blocking_work && + (tws->blocking_inflight.load(std::memory_order_relaxed) < + kMaxBlockingInflight)) { + t = tws->blocking_work_queue.PopBack(); + if (t.f) { + inflight_counter = &(tws->blocking_inflight); + traceme_id = tws->traceme_id.load(std::memory_order_relaxed); + break; + } + } + t = tws->non_blocking_work_queue.PopBack(); + if (t.f) { + inflight_counter = &(tws->non_blocking_inflight); + traceme_id = tws->traceme_id.load(std::memory_order_relaxed); + inter_work = false; + break; + } + } + } + if (t.f) { + profiler::TraceMe activity( + [=] { + return strings::StrCat(inter_work ? "inter" : "intra", " ", + "#id = ", traceme_id, " ", thread_id, "#"); + }, + profiler::TraceMeLevel::kInfo); + VLOG(2) << "Running " << (inter_work ? "inter" : "intra") << " work from " + << traceme_id; + inflight_counter->fetch_add(1, std::memory_order_relaxed); + env_.ExecuteTask(t); + inflight_counter->fetch_sub(1, std::memory_order_relaxed); + } else { + profiler::TraceMe activity( + [=] { + return strings::StrCat("Sleeping#thread_id=", thread_id, "#"); + }, + profiler::TraceMeLevel::kInfo); + if (VLOG_IS_ON(4)) { + mutex_lock l(thread_data_[thread_id].mu); + for (int i = 0; i < thread_work_sources->size(); ++i) { + ThreadWorkSource* tws = (*thread_work_sources)[i]; + VLOG(4) << "source id " << i << " traceme_id = " + << tws->traceme_id.load(std::memory_order_relaxed) + << " inter queue size " << tws->blocking_work_queue.Size() + << " inter inflight " + << tws->blocking_inflight.load(std::memory_order_relaxed) + << " intra queue size " << tws->non_blocking_work_queue.Size() + << " intra inflight " + << tws->non_blocking_inflight.load(std::memory_order_relaxed); + } + } + Env::Default()->SleepForMicroseconds(250); + } + } +} + +} // namespace // Contains the concrete implementation of the RunHandler. // Externally visible RunHandler class simply forwards the work to this one. class RunHandler::Impl { public: - explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) { - Reset(); - } + explicit Impl(RunHandlerPool::Impl* pool_impl); ~Impl() {} - void set_inter_op_scheduling_range(std::uint_fast32_t start, - std::uint_fast32_t limit) { - inter_op_scheduling_range_.store(EncodePartition(start, limit), - std::memory_order_release); - } - - std::uint_fast32_t inter_op_scheduling_range() const { - return inter_op_scheduling_range_.load(std::memory_order_acquire); + thread::ThreadPoolInterface* thread_pool_interface() { + return thread_pool_interface_.get(); } // Stores now time (in microseconds) since unix epoch when the handler is // requested via RunHandlerPool::Get(). 
uint64 start_time_us() const { return start_time_us_; } - + int64 step_id() const { return step_id_; } void ScheduleInterOpClosure(std::function fn); + void ScheduleIntraOpClosure(std::function fn); - void Reset(); + void Reset(int64 step_id); RunHandlerPool::Impl* pool_impl() { return pool_impl_; } + RunHandlerThreadPool::ThreadWorkSource* tws() { return &tws_; } + private: - // Encoding/decoding logic for storing [start, limit) into a single - // uint_fast32_t int. We assume that pool_num_threads < (1 << 16). - const int kMaxPartitionBits = 16; - const int kMaxThreads = 1 << kMaxPartitionBits; + class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface { + public: + explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl) + : run_handler_impl_(run_handler_impl) {} + ~ThreadPoolInterfaceWrapper() override {} + void Schedule(std::function fn) override; + int NumThreads() const override; + int CurrentThreadId() const override; - std::uint_fast32_t EncodePartition(std::uint_fast32_t start, - std::uint_fast32_t limit) { - return (start << kMaxPartitionBits) | limit; - } + private: + RunHandler::Impl* run_handler_impl_ = nullptr; + }; - void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start, - std::uint_fast32_t* limit) { - *limit = val & (kMaxThreads - 1); - val >>= kMaxPartitionBits; - *start = val; - } - - std::atomic_uint_fast32_t inter_op_scheduling_range_; RunHandlerPool::Impl* pool_impl_; // NOT OWNED. uint64 start_time_us_; + int64 step_id_; + std::unique_ptr thread_pool_interface_; + RunHandlerThreadPool::ThreadWorkSource tws_; }; // Contains shared state across all run handlers present in the pool. Also @@ -82,30 +351,17 @@ class RunHandler::Impl { // This class is thread safe. class RunHandlerPool::Impl { public: - explicit Impl(int num_inter_op_threads) - : max_handlers_(128), - inter_op_thread_pool_(new thread::ThreadPool( - Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)), + explicit Impl(int num_inter_op_threads, int num_intra_op_threads) + : max_handlers_(kMaxConcurrentHandlers), + run_handler_thread_pool_(new RunHandlerThreadPool( + num_inter_op_threads, num_intra_op_threads, Env::Default(), + ThreadOptions(), "tf_run_handler_pool")), iterations_(0) { VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_; for (int i = 0; i < max_handlers_; ++i) { handlers_.emplace_back(new RunHandler::Impl(this)); free_handlers_.push_back(handlers_.back().get()); } - - std::vector> steal_partitions( - num_inter_op_threads); - std::vector start_vec(num_inter_op_threads); - std::vector end_vec(num_inter_op_threads); - - ComputeInterOpStealingRanges(num_inter_op_threads, kMinThreadsPerDomain, - &start_vec, &end_vec); - for (int i = 0; i < num_inter_op_threads; ++i) { - steal_partitions[i] = std::make_pair(start_vec[i], end_vec[i]); - VLOG(1) << "Steal partition i: " << i << " steal_start: " << start_vec[i] - << " steal_end: " << end_vec[i]; - } - inter_op_thread_pool_->SetStealPartitions(steal_partitions); } ~Impl() { @@ -116,11 +372,11 @@ class RunHandlerPool::Impl { DCHECK_EQ(sorted_active_handlers_.size(), 0); } - thread::ThreadPool* inter_op_thread_pool() const { - return inter_op_thread_pool_.get(); + RunHandlerThreadPool* run_handler_thread_pool() { + return run_handler_thread_pool_.get(); } - std::unique_ptr Get() LOCKS_EXCLUDED(mu_) { + std::unique_ptr Get(int64 step_id) LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); while (free_handlers_.empty()) { one_handler_free_.wait(l); @@ -128,7 +384,7 @@ class RunHandlerPool::Impl { 
     // Remove the last entry from free_handlers_ and add to the end of
     // sorted_active_handlers_.
     auto* handler_impl = free_handlers_.back();
-    handler_impl->Reset();
+    handler_impl->Reset(step_id);
     // Sortedness isn't violated if we simply add at the end of the list, since
     // handlers are expected to be obtained in increasing order of time.
     sorted_active_handlers_.push_back(handler_impl);
@@ -144,6 +400,9 @@
     mutex_lock l(mu_);
     DCHECK_GT(sorted_active_handlers_.size(), 0);
 
+    CHECK_EQ(handler->tws()->blocking_work_queue.Size(), 0);
+    CHECK_EQ(handler->tws()->non_blocking_work_queue.Size(), 0);
+
     uint64 now = tensorflow::Env::Default()->NowMicros();
     double elapsed = (now - handler->start_time_us()) / 1000.0;
     time_hist_.Add(elapsed);
@@ -176,26 +435,16 @@
   // inference).
   const int max_handlers_;
 
-  // Minimum number of threads allocated to process a request.
-  const int kMinThreadsPerRequest = 3;
-
-  // Minmum number of threads in a steal domain. Each thread will first try
-  // to steal from threads in the same domain before stealing from threads
-  // in different domains.
-  const int kMinThreadsPerDomain = 2 * kMinThreadsPerRequest;
-
-  // Thread safe part.
-  const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
-
+  std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_;
   // Thread compatible part used only by lock under RunHandlerPool.
   // Handlers are sorted by start time.
+  // TODO(azaks): sort by the remaining latency budget.
   std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
   std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
   std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
   // Histogram of elapsed runtime of every handler (in ms).
   histogram::Histogram time_hist_ GUARDED_BY(mu_);
-  std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
-  std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+
   int64 iterations_ GUARDED_BY(mu_);
   condition_variable one_handler_free_;
   mutex mu_;
@@ -204,63 +453,98 @@
 void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
   int num_active_requests = sorted_active_handlers_.size();
   if (num_active_requests == 0) return;
+  Eigen::MaxSizeVector<RunHandlerThreadPool::ThreadWorkSource*>
+      thread_work_sources(num_active_requests);
 
-  int num_threads = inter_op_thread_pool_->NumThreads();
-
-  inter_op_start_.resize(num_active_requests);
-  inter_op_limit_.resize(num_active_requests);
-
-  ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
-                                 kMinThreadsPerRequest, &inter_op_start_,
-                                 &inter_op_limit_);
+  thread_work_sources.resize(num_active_requests);
 
   for (int i = 0; i < num_active_requests; ++i) {
-    sorted_active_handlers_[i]->set_inter_op_scheduling_range(
-        inter_op_start_[i], inter_op_limit_[i]);
+    thread_work_sources[i] = sorted_active_handlers_[i]->tws();
+  }
+  for (int i = 0; i < run_handler_thread_pool()->NumThreads(); ++i) {
+    VLOG(2) << "Setting work for tid = " << i;
+    run_handler_thread_pool()->SetThreadWorkSources(i, thread_work_sources);
   }
 
-  if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
     VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
     VLOG(1) << "Active session runs: " << num_active_requests;
     uint64 now = tensorflow::Env::Default()->NowMicros();
-    string ranges_str = "";
     string times_str = "";
+    string ids_str = "";
     for (int i = 0; i < num_active_requests; ++i) {
       if (i > 0) {
         times_str += " ";
-        ranges_str += " ";
+        ids_str += " ";
       }
       times_str += strings::StrCat(
          (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
-      ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
-                                    inter_op_limit_[i], ")");
+      ids_str +=
+          strings::StrCat(sorted_active_handlers_[i]->tws()->traceme_id.load(
+              std::memory_order_relaxed));
     }
     VLOG(1) << "Elapsed times are: " << times_str;
-    VLOG(1) << "Ranges are: " << ranges_str;
+    VLOG(1) << "Step ids are: " << ids_str;
   }
 }
 
-void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
-  std::uint_fast32_t start = 0, limit = 0;
-  DecodePartition(inter_op_scheduling_range(), &start, &limit);
-  DCHECK_LT(start, limit);
-  pool_impl_->inter_op_thread_pool()->ScheduleWithHint(std::move(fn), start,
-                                                       limit);
+// It is important to return a value such that:
+// CurrentThreadId() in [0, NumThreads)
+int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
+  return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
 }
 
-void RunHandler::Impl::Reset() {
-  set_inter_op_scheduling_range(
-      0, pool_impl_->inter_op_thread_pool()->NumThreads());
+int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
+  return run_handler_impl_->pool_impl_->run_handler_thread_pool()
+      ->CurrentThreadId();
+}
+
+void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
+    std::function<void()> fn) {
+  return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
+}
+
+RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
+    : pool_impl_(pool_impl) {
+  thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
+  Reset(0);
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+  VLOG(3) << "Scheduling inter work for "
+          << tws()->traceme_id.load(std::memory_order_relaxed);
+  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(
+      &tws()->blocking_work_queue, &tws()->blocking_mu, true,
+      &tws()->traceme_id, std::move(fn));
+}
+
+void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
+  VLOG(3) << "Scheduling intra work for "
+          << tws()->traceme_id.load(std::memory_order_relaxed);
+  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(
+      &tws()->non_blocking_work_queue, &tws()->non_blocking_mu, false,
+      &tws()->traceme_id, std::move(fn));
+}
+
+void RunHandler::Impl::Reset(int64 step_id) {
   start_time_us_ = tensorflow::Env::Default()->NowMicros();
+  step_id_ = step_id;
+  tws_.traceme_id = step_id;
 }
 
 RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
-    : impl_(new Impl(num_inter_op_threads)) {}
+    : impl_(new Impl(num_inter_op_threads, 0)) {}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
+                               int num_intra_op_threads)
+    : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}
 
 RunHandlerPool::~RunHandlerPool() {}
 
-std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+std::unique_ptr<RunHandler> RunHandlerPool::Get(int64 step_id) {
+  return impl_->Get(step_id);
+}
 
 RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
 
@@ -268,5 +552,10 @@ void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
   impl_->ScheduleInterOpClosure(std::move(fn));
 }
 
+thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
+  return impl_->thread_pool_interface();
+}
+
 RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
index 72fa6301b47..5c5d96e52ea 100644
--- a/tensorflow/core/framework/run_handler.h
+++ b/tensorflow/core/framework/run_handler.h
@@ -22,6 +22,10 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 
+namespace Eigen {
+struct ThreadPoolDevice;
+}
+
 namespace tensorflow {
 
 class RunHandler;
@@ -46,6 +50,8 @@ class RunHandler;
 class RunHandlerPool {
  public:
   explicit RunHandlerPool(int num_inter_op_threads);
+
+  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
   ~RunHandlerPool();
 
   // Returns an inactive RunHandler from the pool.
@@ -56,7 +62,7 @@ class RunHandlerPool {
   // unique_ptr is destroyed.
   //
   // Will block unless there is an inactive handler.
-  std::unique_ptr<RunHandler> Get();
+  std::unique_ptr<RunHandler> Get(int64 step_id = 0);
 
  private:
   class Impl;
@@ -65,8 +71,10 @@ class RunHandlerPool {
   std::unique_ptr<Impl> impl_;
 };
 
-// RunHandler can be used to schedule inter-op closures to run on a global pool
-// shared across all Session::Run(s).
+// RunHandler can be used to schedule inter/intra-op closures to run on a
+// global pool shared across all Session::Run(s). The closures are enqueued to
+// a handler-specific queue, from which work is stolen in priority order
+// (based on the time of the Get() call).
 //
 // It can only be created via RunHandlerPool::Get().
 //
@@ -78,6 +86,7 @@ class RunHandlerPool {
 class RunHandler {
  public:
   void ScheduleInterOpClosure(std::function<void()> fn);
+  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();
 
   ~RunHandler();
 
diff --git a/tensorflow/core/framework/run_handler_test.cc b/tensorflow/core/framework/run_handler_test.cc
new file mode 100644
index 00000000000..84dcee2adc3
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include <memory>
+#include <vector>
+
+#define EIGEN_USE_THREADS
+#include "absl/memory/memory.h"
+#include "absl/synchronization/barrier.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(RunHandlerUtilTest, TestBasicScheduling) {
+  int num_threads = 2;
+  int num_handlers = 10;
+
+  std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(num_threads));
+
+  // RunHandler has 2 * num_threads (inter + intra) -
+  // all should be able to run concurrently.
+ absl::Barrier barrier1(num_threads); + absl::Barrier barrier2(num_threads); + + BlockingCounter counter(2 * num_handlers * num_threads); + + int num_test_threads = 10; + thread::ThreadPool test_pool(Env::Default(), "test", num_test_threads); + for (int i = 0; i < 10; ++i) { + test_pool.Schedule([&counter, &barrier1, &barrier2, &pool, i, + num_threads]() { + auto handler = pool->Get(); + BlockingCounter local_counter(2 * num_threads); + auto intra_thread_pool = handler->AsIntraThreadPoolInterface(); + + for (int j = 0; j < num_threads; ++j) { + handler->ScheduleInterOpClosure( + [&local_counter, &counter, &barrier1, i]() { + if (i == 2) { + barrier1.Block(); + } + counter.DecrementCount(); + local_counter.DecrementCount(); + }); + intra_thread_pool->Schedule([&local_counter, &counter, &barrier2, i]() { + if (i == 9) { + barrier2.Block(); + } + counter.DecrementCount(); + local_counter.DecrementCount(); + }); + } + local_counter.Wait(); + }); + } + counter.Wait(); +} + +} // namespace +} // namespace tensorflow
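
For reference, below is a minimal usage sketch (not part of the patch) of the API introduced above: the two-argument RunHandlerPool constructor, RunHandlerPool::Get(step_id), RunHandler::ScheduleInterOpClosure(), and RunHandler::AsIntraThreadPoolInterface(), whose result DirectSession now passes as args.user_intra_op_threadpool. The function name RunOneStep, the step id, and the closure bodies are illustrative assumptions, not part of the TensorFlow API.

// Illustrative sketch only; assumes the API added in this patch.
#include <memory>

#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"

namespace tensorflow {

void RunOneStep(RunHandlerPool* pool, int64 step_id) {
  // Blocks until one of the kMaxConcurrentHandlers (128) handlers is free.
  std::unique_ptr<RunHandler> handler = pool->Get(step_id);

  BlockingCounter done(2);

  // Inter-op closures go to the handler's blocking work queue.
  handler->ScheduleInterOpClosure([&done]() { done.DecrementCount(); });

  // Intra-op closures go through the ThreadPoolInterface wrapper; this is the
  // same pointer DirectSession plugs into args.user_intra_op_threadpool.
  thread::ThreadPoolInterface* intra = handler->AsIntraThreadPoolInterface();
  intra->Schedule([&done]() { done.DecrementCount(); });

  // Wait for both closures before the handler is destroyed: ReleaseHandler()
  // now CHECKs that both per-handler work queues are empty.
  done.Wait();
}

}  // namespace tensorflow

A process-wide pool shared across sessions (mirroring GetOrCreateRunHandlerPool() in direct_session.cc) could then invoke this as RunOneStep(pool, /*step_id=*/42) with a pool built as new RunHandlerPool(/*num_inter_op_threads=*/8, /*num_intra_op_threads=*/16); the thread counts here are placeholders.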
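
The thread-count selection added to GetOrCreateRunHandlerPool() in direct_session.cc follows a precedence order that is easy to miss in the diff. The sketch below restates it as a standalone helper for illustration; the helper name ResolveRunHandlerThreads, the include paths, and the worked numbers are assumptions, not code from the patch.

// Paraphrase of the precedence implemented in GetOrCreateRunHandlerPool():
//   inter: environment override > session_inter_op_thread_pool(0) >
//          NumInterOpThreadsFromSessionOptions(options)
//   intra: environment override > intra_op_parallelism_threads >
//          port::NumSchedulableCPUs()
// E.g. with no environment overrides, one session inter-op pool of 4 threads,
// and intra_op_parallelism_threads == 0 on a 16-core host, the process-wide
// pool becomes RunHandlerPool(4, 16).
#include <utility>

#include "tensorflow/core/common_runtime/process_util.h"  // assumed location of the helpers
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

std::pair<int, int> ResolveRunHandlerThreads(const SessionOptions& options) {
  int num_inter = NumInterOpThreadsFromEnvironment();
  int num_intra = NumIntraOpThreadsFromEnvironment();
  if (num_inter <= 0) {
    if (options.config.session_inter_op_thread_pool_size() > 0) {
      num_inter = options.config.session_inter_op_thread_pool(0).num_threads();
    }
    if (num_inter <= 0) {
      num_inter = NumInterOpThreadsFromSessionOptions(options);
    }
  }
  if (num_intra <= 0) {
    num_intra = options.config.intra_op_parallelism_threads();
    if (num_intra <= 0) {
      num_intra = port::NumSchedulableCPUs();
    }
  }
  return {num_inter, num_intra};
}

}  // namespace tensorflow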