From 50b1c27acab7b6b4a220e30eb076155107c1f8f3 Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Thu, 21 Jan 2021 09:40:52 -0800
Subject: [PATCH] [tf.data] Relax locking in the threading dataset ops.

Concurrent `GetNext` calls on a threading dataset iterator do not need to
be serialized with respect to each other. We only need to serialize
`GetNext` calls with respect to saving and restoring the iterator state.

PiperOrigin-RevId: 353038539
Change-Id: I03d027350129ebdb1995166821fb58ef6748b27a
---
 .../kernels/data/experimental/threadpool_dataset_op.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 111d7b2fec2..af2a27d22e6 100644
--- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -210,7 +211,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                   out_tensors, end_of_sequence);
     }
@@ -350,7 +351,7 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
       auto max_parallelism = dataset()->max_intra_op_parallelism_;
       params.runner =
           RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
@@ -480,7 +481,7 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
         pool->Schedule(std::move(c));
       };
       params.runner_threadpool_size = dataset()->num_threads_;
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
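
Note (illustrative, not part of the patch): below is a minimal standalone
C++ sketch of the locking discipline this change adopts. The InnerIterator
and ThreadingIterator names are hypothetical, and std::shared_mutex with
std::shared_lock/std::unique_lock stand in for tensorflow::mutex with
tf_shared_lock/mutex_lock.

  #include <atomic>
  #include <cstddef>
  #include <mutex>
  #include <shared_mutex>
  #include <utility>
  #include <vector>

  // Stand-in for input_impl_: an iterator whose own GetNext is safe to
  // call concurrently (here via an atomic cursor).
  class InnerIterator {
   public:
    explicit InnerIterator(std::vector<int> data) : data_(std::move(data)) {}

    bool GetNext(int* out) {
      size_t i = pos_.fetch_add(1);  // claim the next element atomically
      if (i >= data_.size()) return false;
      *out = data_[i];
      return true;
    }

    size_t position() const { return pos_.load(); }
    void set_position(size_t p) { pos_.store(p); }

   private:
    std::vector<int> data_;
    std::atomic<size_t> pos_{0};
  };

  // Mirrors the patch: GetNext takes a shared (reader) lock, so concurrent
  // GetNext calls proceed in parallel; Save/Restore take an exclusive
  // (writer) lock, so they never interleave with in-flight GetNext calls.
  class ThreadingIterator {
   public:
    bool GetNext(int* out) {
      std::shared_lock<std::shared_mutex> l(mu_);  // analogous to tf_shared_lock
      return inner_.GetNext(out);
    }

    size_t Save() {
      std::unique_lock<std::shared_mutex> l(mu_);  // analogous to mutex_lock
      return inner_.position();
    }

    void Restore(size_t pos) {
      std::unique_lock<std::shared_mutex> l(mu_);
      inner_.set_position(pos);
    }

   private:
    std::shared_mutex mu_;
    InnerIterator inner_{{1, 2, 3}};
  };

As in the patch, the shared lock does not by itself make GetNext
thread-safe; it only guarantees that GetNext never overlaps a save or
restore. The inner iterator (input_impl_ in TensorFlow) remains
responsible for its own internal synchronization.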