Enforce max_parallelism in the work sharder by reverting to the old sharding code when max_parallelism is less than the number of workers in the pool.

Change: 123139558
This commit is contained in:
A. Unique TensorFlower 2016-05-24 12:26:38 -08:00 committed by TensorFlower Gardener
parent d44ed34cd2
commit 67ddfa5b34
4 changed files with 30 additions and 33 deletions

View File

@@ -87,14 +87,11 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
         num_threads_(num_threads) {}
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-    CHECK_GT(max_parallelism, 0);
     CHECK_GE(total, 0);
     CHECK_EQ(total, (int64)(Eigen::Index)total);
-    Eigen::ThreadPoolDevice device(this,
-                                   std::min(num_threads_, max_parallelism));
+    Eigen::ThreadPoolDevice device(this, num_threads_);
     device.parallelFor(
         total, Eigen::TensorOpCost(0, 0, cost_per_unit),
         [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@@ -103,6 +100,8 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
 #endif
   }
+  int NumThreads() const { return num_threads_; };
   const int num_threads_;
 };
@@ -114,11 +113,12 @@ struct ThreadPool::Impl {
   ~Impl();
   void Schedule(std::function<void()> fn);
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
     CHECK(0);  // should not be used with the old thread pool
   }
+  int NumThreads() const { return threads_.size(); };
  private:
   struct Waiter {
     condition_variable cv;
@@ -242,10 +242,11 @@ void ThreadPool::Schedule(std::function<void()> fn) {
 }
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
-                             std::function<void(int64, int64)> fn,
-                             int32 max_parallelism) {
-  impl_->ParallelFor(total, cost_per_unit, std::move(fn), max_parallelism);
+                             std::function<void(int64, int64)> fn) {
+  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
 }
+int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
 }  // namespace thread
 }  // namespace tensorflow

View File

@@ -51,12 +51,11 @@ class ThreadPool {
   // having roughly "cost_per_unit" cost, in cycles. Each unit of work is
   // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
   // and the total cost of each shard is roughly the same.
-  // Max_parallelism optionally caps the number of threads used.
-  //
-  // REQUIRES: max_parallelism > 0.
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max);
+                   std::function<void(int64, int64)> fn);
+  // Returns the number of threads in the pool.
+  int NumThreads() const;
   struct Impl;

View File

@@ -66,22 +66,17 @@ TEST(ThreadPool, ParallelFor) {
   const int kWorkItems = 15;
   bool work[kWorkItems];
   ThreadPool pool(Env::Default(), "test", num_threads);
-  for (int max_parallelism = 1; max_parallelism <= kNumThreads + 1;
-       max_parallelism++) {
-    for (int i = 0; i < kWorkItems; i++) {
-      work[i] = false;
-    }
-    pool.ParallelFor(kWorkItems, kHugeCost,
-                     [&work](int64 begin, int64 end) {
-                       for (int64 i = begin; i < end; ++i) {
-                         ASSERT_FALSE(work[i]);
-                         work[i] = true;
-                       }
-                     },
-                     max_parallelism);
-    for (int i = 0; i < kWorkItems; i++) {
-      ASSERT_TRUE(work[i]);
+  for (int i = 0; i < kWorkItems; i++) {
+    work[i] = false;
+  }
+  pool.ParallelFor(kWorkItems, kHugeCost, [&work](int64 begin, int64 end) {
+    for (int64 i = begin; i < end; ++i) {
+      ASSERT_FALSE(work[i]);
+      work[i] = true;
+    }
+  });
+  for (int i = 0; i < kWorkItems; i++) {
+    ASSERT_TRUE(work[i]);
   }
 }

View File

@@ -32,8 +32,11 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     return;
   }
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-  workers->ParallelFor(total, cost_per_unit, work, max_parallelism);
-#else
+  if (max_parallelism >= workers->NumThreads()) {
+    workers->ParallelFor(total, cost_per_unit, work);
+    return;
+  }
+#endif
   cost_per_unit = std::max(1LL, cost_per_unit);
   // We shard [0, total) into "num_shards" shards.
   //   1 <= num_shards <= num worker threads
@@ -71,7 +74,6 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
   // Inline execute the 1st shard.
   work(0, std::min(block_size, total));
   counter.Wait();
-#endif
 }
} // end namespace tensorflow