Merge pull request #23233 from feihugis:Issue-Max_Parallelism_Thread_For_Parallel_Map_Dataset

PiperOrigin-RevId: 220207837
Author: TensorFlower Gardener
Date:   2018-11-05 18:19:06 -08:00
Commit: a7faefeb5c
5 changed files with 18 additions and 11 deletions


@@ -279,12 +279,15 @@ class IteratorContext {
           lib(ctx->lib()),
           model(ctx->model()),
           runner(*(ctx->runner())),
+          runner_threadpool_size(ctx->runner_threadpool_size()),
           stats_aggregator(ctx->stats_aggregator()) {}

     explicit Params(OpKernelContext* ctx)
         : env(ctx->env()),
           lib(ctx->function_library()),
-          runner(*(ctx->runner())) {
+          runner(*(ctx->runner())),
+          runner_threadpool_size(
+              ctx->device()->tensorflow_cpu_worker_threads()->num_threads) {
       // NOTE: need reinterpret_cast because function.h forward-declares Device.
       DeviceBase* device =
           reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
@@ -311,6 +314,9 @@ class IteratorContext {
     // Function call support.
     std::function<void(std::function<void()>)> runner = nullptr;

+    // Number of threads used for executing user-defined functions.
+    int32 runner_threadpool_size = 0;
+
     // The `StatsAggregator` object to record statistics about the iterator.
     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
   };
@@ -343,6 +349,8 @@ class IteratorContext {
     return &params_.runner;
   }

+  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
+
   std::shared_ptr<StatsAggregator> stats_aggregator() {
     return params_.stats_aggregator;
   }
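Taken together, the three hunks above plumb the runner's threadpool size through IteratorContext: the Params(IteratorContext*) constructor copies it from an existing context, the Params(OpKernelContext*) constructor derives it from the device's CPU worker threads, and the new accessor exposes it to iterators. A minimal consumer-side sketch, assuming a hypothetical iterator (MyIterator and parallelism_ are illustrative, not part of this commit):

    // Hypothetical consumer of the new accessor: size parallelism to the
    // executor that will actually run the work, not to the host CPU count.
    Status MyIterator::Initialize(IteratorContext* ctx) {
      int32 parallelism = ctx->runner_threadpool_size();
      if (parallelism < 1) parallelism = 1;  // guard manually-built Params (field defaults to 0)
      parallelism_ = parallelism;
      return Status::OK();
    }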


@@ -201,7 +201,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == kAutoTune) {
-        num_parallel_calls_->value = port::NumSchedulableCPUs();
+        num_parallel_calls_->value = ctx->runner_threadpool_size();
         num_parallel_calls_->tunable = true;
       }
       TF_RETURN_IF_ERROR(
@@ -244,7 +244,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       return model::MakeAsyncKnownRatioNode(
           std::move(args), dataset()->batch_size_,
           {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
-                                /*max=*/port::NumSchedulableCPUs())});
+                                /*max=*/ctx->runner_threadpool_size())});
     }

     Status SaveInternal(IteratorStateWriter* writer) override {
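The practical effect of both hunks: when this iterator runs under a custom runner (for example one installed by ThreadPoolDataset, below), port::NumSchedulableCPUs() reports the host CPU count and can oversubscribe the pool, while ctx->runner_threadpool_size() reflects the threads that actually execute the user-defined functions. An illustrative comparison, with hypothetical numbers (an 8-core host driving the iterator through a 2-thread custom pool):

    // 8-core host, iterator driven by a 2-thread custom pool:
    port::NumSchedulableCPUs();     // -> 8: ignores the custom pool
    ctx->runner_threadpool_size();  // -> 2: matches the pool running the functions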


@@ -47,6 +47,8 @@ class ThreadPoolResource : public ResourceBase {
     }
   }

+  int32 NumThreads() { return thread_pool_.NumThreads(); }
+
   string DebugString() override { return "ThreadPoolResource"; }

 private:
@@ -196,6 +198,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
       params.runner = [pool](std::function<void()> c) {
         pool->Schedule(std::move(c));
       };
+      params.runner_threadpool_size = pool->NumThreads();
       IteratorContext iter_ctx(params);
       return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
     }
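The NumThreads() accessor added to ThreadPoolResource exists so the op can advertise the pool's size alongside the runner it installs; keeping the two Params fields in sync is what makes the autotune defaults above track the right executor. A condensed sketch of the producer side, using only identifiers that appear in the hunks (pool is the ThreadPoolResource captured by the op):

    // The runner and its advertised size come from the same pool, so
    // downstream autotuning never schedules more in-flight calls than the
    // pool has threads.
    IteratorContext::Params params(ctx);
    params.runner = [pool](std::function<void()> c) {
      pool->Schedule(std::move(c));
    };
    params.runner_threadpool_size = pool->NumThreads();
    IteratorContext iter_ctx(params);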


@@ -262,9 +262,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == kAutoTune) {
-        // TODO(jsimsa): Surface the number of threads used by `ctx->runner()`
-        // and use it here for the default.
-        num_parallel_calls_->value = port::NumSchedulableCPUs();
+        num_parallel_calls_->value = ctx->runner_threadpool_size();
         num_parallel_calls_->tunable = true;
       }
       TF_RETURN_IF_ERROR(
@@ -298,7 +296,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
      return model::MakeAsyncKnownRatioNode(
          std::move(args), dataset()->batch_size_,
          {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
-                               /*max=*/port::NumSchedulableCPUs())});
+                               /*max=*/ctx->runner_threadpool_size())});
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
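Both halves of this change matter: Initialize() picks the starting value when the user requested autotuning (kAutoTune), and the min/max passed to MakeParameter bound what the optimization model may choose later. A sketch of the resulting semantics, writing the runner threadpool size as N (not code from this commit):

    // With num_parallel_calls == kAutoTune and a runner threadpool of N threads:
    //   initial value : N       (set in Initialize)
    //   tuning range  : [1, N]  (/*min=*/1, /*max=*/N in MakeParameter)
    //   tunable       : true    (the model may move it within the range)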


@@ -65,9 +65,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
     if (num_parallel_calls_->value == kAutoTune) {
-      // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
-      // use it here for the default.
-      num_parallel_calls_->value = port::NumSchedulableCPUs();
+      num_parallel_calls_->value = ctx->runner_threadpool_size();
      num_parallel_calls_->tunable = true;
    }
    TF_RETURN_IF_ERROR(
@@ -103,7 +101,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
-                             /*max=*/port::NumSchedulableCPUs())});
+                             /*max=*/ctx->runner_threadpool_size())});
  }

  Status SaveInternal(IteratorStateWriter* writer) override {
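This mirrors the map_and_batch change; the structural difference between the two model nodes is only the ratio argument of MakeAsyncKnownRatioNode, visible in the hunks above: map_and_batch consumes dataset()->batch_size_ input elements per output batch, while parallel_map is one-to-one. Side by side (the trailing {...} elides the identical parallelism parameter):

    // map_and_batch: batch_size_ inputs per output element
    model::MakeAsyncKnownRatioNode(std::move(args), dataset()->batch_size_, {...});
    // parallel_map: one input per output element
    model::MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/1, {...});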