Merge pull request #23233 from feihugis:Issue-Max_Parallelism_Thread_For_Parallel_Map_Dataset
PiperOrigin-RevId: 220207837
This commit is contained in:
commit
a7faefeb5c
@ -279,12 +279,15 @@ class IteratorContext {
|
||||
lib(ctx->lib()),
|
||||
model(ctx->model()),
|
||||
runner(*(ctx->runner())),
|
||||
runner_threadpool_size(ctx->runner_threadpool_size()),
|
||||
stats_aggregator(ctx->stats_aggregator()) {}
|
||||
|
||||
explicit Params(OpKernelContext* ctx)
|
||||
: env(ctx->env()),
|
||||
lib(ctx->function_library()),
|
||||
runner(*(ctx->runner())) {
|
||||
runner(*(ctx->runner())),
|
||||
runner_threadpool_size(
|
||||
ctx->device()->tensorflow_cpu_worker_threads()->num_threads) {
|
||||
// NOTE: need reinterpret_cast because function.h forward-declares Device.
|
||||
DeviceBase* device =
|
||||
reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
|
||||
@ -311,6 +314,9 @@ class IteratorContext {
|
||||
// Function call support.
|
||||
std::function<void(std::function<void()>)> runner = nullptr;
|
||||
|
||||
// Number of threads used for executing user-defined functions.
|
||||
int32 runner_threadpool_size = 0;
|
||||
|
||||
// The `StatsAggregator` object to record statistics about the iterator.
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
|
||||
};
|
||||
@ -343,6 +349,8 @@ class IteratorContext {
|
||||
return ¶ms_.runner;
|
||||
}
|
||||
|
||||
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
|
||||
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator() {
|
||||
return params_.stats_aggregator;
|
||||
}
|
||||
|
||||
@ -201,7 +201,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(*mu_);
|
||||
if (num_parallel_calls_->value == kAutoTune) {
|
||||
num_parallel_calls_->value = port::NumSchedulableCPUs();
|
||||
num_parallel_calls_->value = ctx->runner_threadpool_size();
|
||||
num_parallel_calls_->tunable = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -244,7 +244,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeAsyncKnownRatioNode(
|
||||
std::move(args), dataset()->batch_size_,
|
||||
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
|
||||
/*max=*/port::NumSchedulableCPUs())});
|
||||
/*max=*/ctx->runner_threadpool_size())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
|
||||
@ -47,6 +47,8 @@ class ThreadPoolResource : public ResourceBase {
|
||||
}
|
||||
}
|
||||
|
||||
int32 NumThreads() { return thread_pool_.NumThreads(); }
|
||||
|
||||
string DebugString() override { return "ThreadPoolResource"; }
|
||||
|
||||
private:
|
||||
@ -196,6 +198,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
params.runner = [pool](std::function<void()> c) {
|
||||
pool->Schedule(std::move(c));
|
||||
};
|
||||
params.runner_threadpool_size = pool->NumThreads();
|
||||
IteratorContext iter_ctx(params);
|
||||
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
@ -262,9 +262,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(*mu_);
|
||||
if (num_parallel_calls_->value == kAutoTune) {
|
||||
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()`
|
||||
// and use it here for the default.
|
||||
num_parallel_calls_->value = port::NumSchedulableCPUs();
|
||||
num_parallel_calls_->value = ctx->runner_threadpool_size();
|
||||
num_parallel_calls_->tunable = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -298,7 +296,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeAsyncKnownRatioNode(
|
||||
std::move(args), dataset()->batch_size_,
|
||||
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
|
||||
/*max=*/port::NumSchedulableCPUs())});
|
||||
/*max=*/ctx->runner_threadpool_size())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
|
||||
@ -65,9 +65,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(*mu_);
|
||||
if (num_parallel_calls_->value == kAutoTune) {
|
||||
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
|
||||
// use it here for the default.
|
||||
num_parallel_calls_->value = port::NumSchedulableCPUs();
|
||||
num_parallel_calls_->value = ctx->runner_threadpool_size();
|
||||
num_parallel_calls_->tunable = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -103,7 +101,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
std::move(args),
|
||||
/*ratio=*/1,
|
||||
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
|
||||
/*max=*/port::NumSchedulableCPUs())});
|
||||
/*max=*/ctx->runner_threadpool_size())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user