[tf.data] Initialize num_parallel_calls of ParallelBatchDataset to the threadpool size if autotuning is off but its value is set to AUTOTUNE by the user.
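
For context, the configuration this change targets looks roughly like the Python sketch below. It is a minimal illustration, not part of this commit: the num_parallel_calls argument on batch() and the options.autotune.enabled knob are the current tf.data spellings (older releases spell the latter options.experimental_optimization.autotune), and the concrete numbers are arbitrary.

    import tensorflow as tf

    ds = tf.data.Dataset.range(1_000)
    # The user asks for autotuned parallel batching...
    ds = ds.batch(32, num_parallel_calls=tf.data.AUTOTUNE)

    # ...but disables autotuning for the pipeline as a whole.
    opts = tf.data.Options()
    opts.autotune.enabled = False  # older TF: opts.experimental_optimization.autotune = False
    ds = ds.with_options(opts)

    # Previously ParallelBatchDataset left the parallelism at 1 in this case;
    # with this change it falls back to the runner threadpool size.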

PiperOrigin-RevId: 356275240
Change-Id: Idaa5f75fb8c4eceb500826e9b9125fa3babf8774
Jay Shi 2021-02-08 08:50:48 -08:00 committed by TensorFlower Gardener
parent af95667a98
commit 20da5546cb


@@ -181,7 +181,20 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(*mu_);
       if (num_parallel_calls_->value == model::kAutotune) {
-        num_parallel_calls_->value = 1;
+        // If autotuning is enabled, we initialize the parallelism to 1 to
+        // avoid accidentally running the machine out of memory before the
+        // optimization can pick values that respect the memory budget.
+        //
+        // If autotuning is disabled but the transformation uses `AUTOTUNE`, we
+        // default the parallelism to the size of the threadpool used for
+        // executing the user-defined computation. If this causes OOM, the
+        // input pipeline should either enable autotuning, or replace
+        // `AUTOTUNE` with fixed parallelism.
+        if (TF_PREDICT_TRUE(ctx->model())) {
+          num_parallel_calls_->value = 1;
+        } else {
+          num_parallel_calls_->value = ctx->runner_threadpool_size();
+        }
       }
       TF_RETURN_IF_ERROR(RegisterCancellationCallback(
           ctx->cancellation_manager(),
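
If the threadpool-sized default does cause OOM, the remedies named in the comment above translate to something like the following sketch (same hypothetical pipeline as before; illustrative only, not part of this commit, and the options spelling is again version-dependent):

    import tensorflow as tf

    ds = tf.data.Dataset.range(1_000)

    # Option 1: keep AUTOTUNE but leave autotuning enabled, so the
    # optimization can pick parallelism values that respect the memory budget.
    opts = tf.data.Options()
    opts.autotune.enabled = True
    ds_a = ds.batch(32, num_parallel_calls=tf.data.AUTOTUNE).with_options(opts)

    # Option 2: drop AUTOTUNE and pin the parallelism to a fixed value.
    ds_b = ds.batch(32, num_parallel_calls=4)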