[tf.data] Relax locking in the threading dataset ops.
Concurrent calls to a threading dataset iterator's `GetNext` do not need to be serialized; we only need to serialize `GetNext` w.r.t. saving and restoring the iterator state.

PiperOrigin-RevId: 353038539
Change-Id: I03d027350129ebdb1995166821fb58ef6748b27a
commit 50b1c27aca (parent 63c386bc3a)
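The change below swaps the exclusive mutex_lock in each GetNextInternal for a tf_shared_lock, so concurrent GetNext callers hold the mutex in shared mode while the save/restore paths keep exclusive ownership. Here is a minimal, self-contained sketch of that reader/writer pattern, written against std::shared_mutex rather than TensorFlow's mutex types; the class and method names are illustrative only, not the actual op code.

// Sketch of the locking pattern this commit adopts (not TensorFlow code).
#include <mutex>
#include <shared_mutex>

class Iterator {
 public:
  // GetNext takes the lock in shared mode: any number of concurrent
  // callers may pull elements from the input iterator at once.
  void GetNext() {
    std::shared_lock<std::shared_mutex> l(mu_);
    // ... produce the next element ...
  }

  // Save/Restore take the lock exclusively: they wait for all in-flight
  // GetNext calls to drain and block new ones while state is serialized.
  void Save() {
    std::unique_lock<std::shared_mutex> l(mu_);
    // ... write iterator state to a checkpoint ...
  }
  void Restore() {
    std::unique_lock<std::shared_mutex> l(mu_);
    // ... read iterator state back from a checkpoint ...
  }

 private:
  std::shared_mutex mu_;
};

In TensorFlow's mutex.h, tf_shared_lock and mutex_lock play the roles of std::shared_lock and std::unique_lock over the same mutex, which is why the diff can flip GetNextInternal to shared mode without touching the save/restore paths.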
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -210,7 +211,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                   out_tensors, end_of_sequence);
     }
@@ -350,7 +351,7 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
       auto max_parallelism = dataset()->max_intra_op_parallelism_;
       params.runner =
           RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
@@ -480,7 +481,7 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
         pool->Schedule(std::move(c));
       };
       params.runner_threadpool_size = dataset()->num_threads_;
-      mutex_lock l(mu_);
+      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
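The same one-line substitution appears in all three ops (ThreadPoolDatasetOp, MaxIntraOpParallelismDatasetOp, PrivateThreadPoolDatasetOp). To see the effect, here is a hedged, self-contained driver in plain C++17 (again not TensorFlow code): the consumer threads acquire the lock in shared mode and overlap, while the checkpoint thread's exclusive lock serializes against all of them, mirroring how the save/restore paths presumably still take the exclusive mutex_lock.

// Driver for the sketch above; build with -std=c++17 -pthread.
#include <chrono>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <vector>

int main() {
  std::shared_mutex mu;

  auto get_next = [&](int id) {
    std::shared_lock<std::shared_mutex> l(mu);  // shared: readers overlap
    std::cout << "GetNext running on thread " << id << "\n";
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
  };
  auto save = [&] {
    std::unique_lock<std::shared_mutex> l(mu);  // exclusive: waits for readers
    std::cout << "Save: exclusive access to iterator state\n";
  };

  std::vector<std::thread> consumers;
  for (int i = 0; i < 4; ++i) consumers.emplace_back(get_next, i);
  std::thread checkpoint(save);
  for (auto& t : consumers) t.join();
  checkpoint.join();
  return 0;
}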