[tf.data] Relax locking in the threading dataset ops.

Concurrent calls to `GetNext` on the threading dataset iterators do not need to be serialized with respect to one another. We only need to serialize calls to `GetNext` with respect to saving and restoring the iterator state.

PiperOrigin-RevId: 353038539
Change-Id: I03d027350129ebdb1995166821fb58ef6748b27a
This commit is contained in:
Jiri Simsa 2021-01-21 09:40:52 -08:00 committed by TensorFlower Gardener
parent 63c386bc3a
commit 50b1c27aca

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
@ -210,7 +211,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
// Produces the next element by forwarding to the input iterator, passing it
// an IteratorContext built from CreateParams(ctx).
//
// `mu_` is taken in shared mode only: concurrent GetNext calls do not need
// to be serialized against each other, only against saving and restoring
// the iterator state.
// NOTE(review): the save/restore paths are not visible here; they are
// presumably the ones that take `mu_` exclusively -- confirm.
//
// Args:
//   ctx: caller-supplied iterator context used to derive the params.
//   out_tensors: forwarded to the input iterator's GetNext.
//   end_of_sequence: forwarded to the input iterator's GetNext.
// Returns the Status reported by the input iterator's GetNext.
Status GetNextInternal(IteratorContext* ctx,
                       std::vector<Tensor>* out_tensors,
                       bool* end_of_sequence) override {
  // Exactly one lock object: declaring both an exclusive `mutex_lock l` and a
  // `tf_shared_lock l` in this scope would be a redeclaration of `l`.
  tf_shared_lock l(mu_);
  return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                              out_tensors, end_of_sequence);
}
@ -350,7 +351,7 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
auto max_parallelism = dataset()->max_intra_op_parallelism_;
params.runner =
RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
mutex_lock l(mu_);
tf_shared_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
@ -480,7 +481,7 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
pool->Schedule(std::move(c));
};
params.runner_threadpool_size = dataset()->num_threads_;
mutex_lock l(mu_);
tf_shared_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}