diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 2ccee283b4c..af5d44565cb 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -87,14 +87,11 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
         num_threads_(num_threads) {}
 
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-    CHECK_GT(max_parallelism, 0);
     CHECK_GE(total, 0);
     CHECK_EQ(total, (int64)(Eigen::Index)total);
-    Eigen::ThreadPoolDevice device(this,
-                                   std::min(num_threads_, max_parallelism));
+    Eigen::ThreadPoolDevice device(this, num_threads_);
     device.parallelFor(
         total, Eigen::TensorOpCost(0, 0, cost_per_unit),
         [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@@ -103,6 +100,8 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
 #endif
   }
 
+  int NumThreads() const { return num_threads_; };
+
   const int num_threads_;
 };
 
@@ -114,11 +113,12 @@ struct ThreadPool::Impl {
   ~Impl();
   void Schedule(std::function<void()> fn);
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
     CHECK(0);  // should not be used with the old thread pool
   }
 
+  int NumThreads() const { return threads_.size(); };
+
  private:
   struct Waiter {
     condition_variable cv;
@@ -242,10 +242,11 @@ void ThreadPool::Schedule(std::function<void()> fn) {
 }
 
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
-                             std::function<void(int64, int64)> fn,
-                             int32 max_parallelism) {
-  impl_->ParallelFor(total, cost_per_unit, std::move(fn), max_parallelism);
+                             std::function<void(int64, int64)> fn) {
+  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
 }
 
+int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
+
 }  // namespace thread
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 30049fb2520..fe7f2d0d86b 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -51,12 +51,11 @@ class ThreadPool {
   // having roughly "cost_per_unit" cost, in cycles. Each unit of work is
   // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
   // and the total cost of each shard is roughly the same.
-  // Max_parallelism optionally caps the number of threads used.
-  //
-  // REQUIRES: max_parallelism > 0.
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max);
+                   std::function<void(int64, int64)> fn);
+
+  // Returns the number of threads in the pool.
+  int NumThreads() const;
 
   struct Impl;
 
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 524af800d87..5043e54459a 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -66,22 +66,17 @@ TEST(ThreadPool, ParallelFor) {
     const int kWorkItems = 15;
     bool work[kWorkItems];
     ThreadPool pool(Env::Default(), "test", num_threads);
-    for (int max_parallelism = 1; max_parallelism <= kNumThreads + 1;
-         max_parallelism++) {
-      for (int i = 0; i < kWorkItems; i++) {
-        work[i] = false;
-      }
-      pool.ParallelFor(kWorkItems, kHugeCost,
-                       [&work](int64 begin, int64 end) {
-                         for (int64 i = begin; i < end; ++i) {
-                           ASSERT_FALSE(work[i]);
-                           work[i] = true;
-                         }
-                       },
-                       max_parallelism);
-      for (int i = 0; i < kWorkItems; i++) {
-        ASSERT_TRUE(work[i]);
+    for (int i = 0; i < kWorkItems; i++) {
+      work[i] = false;
+    }
+    pool.ParallelFor(kWorkItems, kHugeCost, [&work](int64 begin, int64 end) {
+      for (int64 i = begin; i < end; ++i) {
+        ASSERT_FALSE(work[i]);
+        work[i] = true;
       }
+    });
+    for (int i = 0; i < kWorkItems; i++) {
+      ASSERT_TRUE(work[i]);
     }
   }
 }
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 192b1dcd7c4..1c454d08fa6 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -32,8 +32,11 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     return;
   }
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-  workers->ParallelFor(total, cost_per_unit, work, max_parallelism);
-#else
+  if (max_parallelism >= workers->NumThreads()) {
+    workers->ParallelFor(total, cost_per_unit, work);
+    return;
+  }
+#endif
   cost_per_unit = std::max(1LL, cost_per_unit);
   // We shard [0, total) into "num_shards" shards.
   //   1 <= num_shards <= num worker threads
@@ -71,7 +74,6 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
   // Inline execute the 1st shard.
   work(0, std::min(block_size, total));
   counter.Wait();
-#endif
 }
 
 }  // end namespace tensorflow
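
For reference, a minimal caller-side sketch of the API after this change (not part of the patch; the Example() function and the literal thread/work counts are made up for illustration). ParallelFor no longer takes a max_parallelism cap and always parallelizes across the whole pool, the new NumThreads() exposes the pool size, and callers that still want to cap parallelism go through Shard(), which forwards to ParallelFor only when max_parallelism >= workers->NumThreads().

// Illustrative sketch only -- not part of this change.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

void Example() {
  thread::ThreadPool pool(Env::Default(), "example", /*num_threads=*/4);

  // New accessor introduced by this change.
  const int pool_size = pool.NumThreads();  // 4

  // ParallelFor now always uses every thread in the pool; the per-call
  // max_parallelism argument is gone.
  pool.ParallelFor(/*total=*/100, /*cost_per_unit=*/1 << 20,
                   [](int64 begin, int64 end) {
                     for (int64 i = begin; i < end; ++i) {
                       // ... process work unit i ...
                     }
                   });

  // To cap parallelism, use Shard(): it forwards to ParallelFor only when
  // max_parallelism >= pool.NumThreads(), and otherwise falls back to its
  // own manual sharding loop.
  Shard(/*max_parallelism=*/pool_size / 2, &pool, /*total=*/100,
        /*cost_per_unit=*/1 << 20,
        [](int64 begin, int64 end) {
          for (int64 i = begin; i < end; ++i) {
            // ... process work unit i ...
          }
        });
}

}  // namespace tensorflow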