Enforce max_parallelism in the work sharder by reverting to the old sharding code if max_parallelism is less than the number of workers in the pool.
Change: 123139558
parent d44ed34cd2
commit 67ddfa5b34
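In short: ThreadPool::ParallelFor loses its max_parallelism argument, and the cap is instead honored inside Shard(), which only hands the whole range to the Eigen-backed ParallelFor when the cap covers every worker in the pool. A minimal sketch of that dispatch rule, not the literal code (the fallback body is elided; see the Shard() hunks at the end of the diff):

// Sketch only: Shard() dispatches between the Eigen path and the old manual path.
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
           int64 cost_per_unit, std::function<void(int64, int64)> work) {
#ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
  if (max_parallelism >= workers->NumThreads()) {
    // ParallelFor now always uses every thread in the pool, so it is only
    // safe to delegate when the requested cap does not restrict parallelism.
    workers->ParallelFor(total, cost_per_unit, work);
    return;
  }
#endif
  // Otherwise fall back to the old manual sharding code, which creates at
  // most max_parallelism shards (omitted in this sketch).
}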
@@ -87,14 +87,11 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
         num_threads_(num_threads) {}
 
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-    CHECK_GT(max_parallelism, 0);
     CHECK_GE(total, 0);
     CHECK_EQ(total, (int64)(Eigen::Index)total);
-    Eigen::ThreadPoolDevice device(this,
-                                   std::min(num_threads_, max_parallelism));
+    Eigen::ThreadPoolDevice device(this, num_threads_);
     device.parallelFor(
         total, Eigen::TensorOpCost(0, 0, cost_per_unit),
         [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@@ -103,6 +100,8 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
 #endif
   }
 
+  int NumThreads() const { return num_threads_; };
+
   const int num_threads_;
 };
 
@@ -114,11 +113,12 @@ struct ThreadPool::Impl {
   ~Impl();
   void Schedule(std::function<void()> fn);
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
     CHECK(0);  // should not be used with the old thread pool
   }
 
+  int NumThreads() const { return threads_.size(); };
+
  private:
   struct Waiter {
     condition_variable cv;
@@ -242,10 +242,11 @@ void ThreadPool::Schedule(std::function<void()> fn) {
 }
 
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
-                             std::function<void(int64, int64)> fn,
-                             int32 max_parallelism) {
-  impl_->ParallelFor(total, cost_per_unit, std::move(fn), max_parallelism);
+                             std::function<void(int64, int64)> fn) {
+  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
 }
 
+int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
+
 }  // namespace thread
 }  // namespace tensorflow
@@ -51,12 +51,11 @@ class ThreadPool {
   // having roughly "cost_per_unit" cost, in cycles. Each unit of work is
   // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
   // and the total cost of each shard is roughly the same.
-  // Max_parallelism optionally caps the number of threads used.
-  //
-  // REQUIRES: max_parallelism > 0.
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max);
+                   std::function<void(int64, int64)> fn);
+
+  // Returns the number of threads in the pool.
+  int NumThreads() const;
 
   struct Impl;
 
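For reference, a hedged usage sketch of the new ParallelFor signature. Env::Default(), the ThreadPool constructor, and the ParallelFor declaration all appear elsewhere on this page; the doubling workload is invented for illustration:

// Usage sketch only (assumes <vector> and the tensorflow::thread namespace).
std::vector<double> out(1000);
thread::ThreadPool pool(Env::Default(), "demo", /*num_threads=*/4);
pool.ParallelFor(out.size(), /*cost_per_unit=*/500,
                 [&out](int64 begin, int64 end) {
                   for (int64 i = begin; i < end; ++i) {
                     out[i] = 2.0 * i;  // every index in [0, total) visited exactly once
                   }
                 });

ParallelFor splits [0, total) into shards of roughly equal cost and runs them on the pool's threads; after this change, callers that need to cap parallelism below the pool size go through Shard() instead.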
@@ -66,22 +66,17 @@ TEST(ThreadPool, ParallelFor) {
     const int kWorkItems = 15;
     bool work[kWorkItems];
     ThreadPool pool(Env::Default(), "test", num_threads);
-    for (int max_parallelism = 1; max_parallelism <= kNumThreads + 1;
-         max_parallelism++) {
-      for (int i = 0; i < kWorkItems; i++) {
-        work[i] = false;
-      }
-      pool.ParallelFor(kWorkItems, kHugeCost,
-                       [&work](int64 begin, int64 end) {
-                         for (int64 i = begin; i < end; ++i) {
-                           ASSERT_FALSE(work[i]);
-                           work[i] = true;
-                         }
-                       },
-                       max_parallelism);
-      for (int i = 0; i < kWorkItems; i++) {
-        ASSERT_TRUE(work[i]);
-      }
-    }
+    for (int i = 0; i < kWorkItems; i++) {
+      work[i] = false;
+    }
+    pool.ParallelFor(kWorkItems, kHugeCost, [&work](int64 begin, int64 end) {
+      for (int64 i = begin; i < end; ++i) {
+        ASSERT_FALSE(work[i]);
+        work[i] = true;
+      }
+    });
+    for (int i = 0; i < kWorkItems; i++) {
+      ASSERT_TRUE(work[i]);
+    }
   }
 }
@@ -32,8 +32,11 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     return;
   }
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-  workers->ParallelFor(total, cost_per_unit, work, max_parallelism);
-#else
+  if (max_parallelism >= workers->NumThreads()) {
+    workers->ParallelFor(total, cost_per_unit, work);
+    return;
+  }
+#endif
   cost_per_unit = std::max(1LL, cost_per_unit);
   // We shard [0, total) into "num_shards" shards.
   //   1 <= num_shards <= num worker threads
@@ -71,7 +74,6 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
   // Inline execute the 1st shard.
   work(0, std::min(block_size, total));
   counter.Wait();
-#endif
 }
 
 }  // end namespace tensorflow
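The manual fallback path is only visible here through its comments ("We shard [0, total) into num_shards shards", "Inline execute the 1st shard"). Below is a rough sketch of that style of blocked sharding, under stated assumptions: int64, thread::ThreadPool, and Schedule() come from the surrounding TensorFlow context, BlockingCounter stands in for whatever counter backs counter.Wait(), and the real block-size heuristic also weighs cost_per_unit, which is omitted here.

// Sketch only, not the code from the work sharder itself.
void ManualShardSketch(int max_parallelism, thread::ThreadPool* workers,
                       int64 total, std::function<void(int64, int64)> work) {
  if (total == 0) return;  // mirrors the early-out in Shard()
  // 1 <= num_shards <= number of worker threads, further capped by the
  // caller-supplied max_parallelism.
  const int num_shards =
      std::max(1, std::min(max_parallelism, workers->NumThreads()));
  const int64 block_size = (total + num_shards - 1) / num_shards;  // ceiling
  const int64 num_blocks = (total + block_size - 1) / block_size;
  BlockingCounter counter(static_cast<int>(num_blocks - 1));  // assumed helper
  for (int64 start = block_size; start < total; start += block_size) {
    const int64 limit = std::min(start + block_size, total);
    workers->Schedule([&work, &counter, start, limit]() {
      work(start, limit);
      counter.DecrementCount();
    });
  }
  // Run the first shard inline on the calling thread, then wait for the rest.
  work(0, std::min(block_size, total));
  counter.Wait();
}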