From 3d854a744d1236944eb0ecdc172b1825ace565e1 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 14 Dec 2017 11:08:00 -0800 Subject: [PATCH] [tf.data] Reimplement `tf.contrib.data.get_single_element()` as an async op. The current `ToSingleElementOp` kernel has a synchronous implementation, and yet it can block an inter-op threadpool thread (in `iterator->GetNext()`) while depending on another (e.g. if the iterator calls a TensorFlow function to produce an element). This can lead to deadlock if the number of inter-op threadpool threads is less than or equal to the number of concurrent activations of that kernel. This change fixes that deadlock by moving the blocking computation onto a background thread. PiperOrigin-RevId: 179067816 --- tensorflow/core/kernels/iterator_ops.cc | 72 ++++++++++++++++--------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index 439775157bc..4e81d40a826 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -448,40 +448,60 @@ class MakeIteratorOp : public OpKernel { } }; -class ToSingleElementOp : public OpKernel { +class ToSingleElementOp : public AsyncOpKernel { public: - explicit ToSingleElementOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit ToSingleElementOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_single_element_op_thread_", + SanitizeThreadSuffix(name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} - void Compute(OpKernelContext* ctx) override { - DatasetBase* dataset; - OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); - auto iterator = dataset->MakeIterator("SingleElementIterator"); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op 
thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([ctx, done]() { + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); + auto iterator = dataset->MakeIterator("SingleElementIterator"); - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - IteratorContext iter_ctx(std::move(params)); + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + IteratorContext iter_ctx(std::move(params)); - std::vector<Tensor> components; - components.reserve(dataset->output_dtypes().size()); - bool end_of_sequence; + std::vector<Tensor> components; + components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence; - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); - OP_REQUIRES(ctx, !end_of_sequence, - errors::InvalidArgument("Dataset was empty.")); + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + OP_REQUIRES_ASYNC(ctx, !end_of_sequence, + errors::InvalidArgument("Dataset was empty."), done); - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. 
+ ctx->set_output(i, components[i]); + } - components.clear(); - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); - OP_REQUIRES(ctx, end_of_sequence, - errors::InvalidArgument("Dataset had more than one element.")); + components.clear(); + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + OP_REQUIRES_ASYNC( + ctx, end_of_sequence, + errors::InvalidArgument("Dataset had more than one element."), done); + + done(); + }); } + + private: + std::unique_ptr<thread::ThreadPool> thread_pool_; }; class OneShotIteratorOp : public AsyncOpKernel {