[tf.data] Reimplement tf.contrib.data.get_single_element() as an async op.

The current `ToSingleElementOp` kernel has a synchronous
implementation, and yet it can block an inter-op threadpool thread (in
`iterator->GetNext()`) while depending on another (e.g. if the
iterator calls a TensorFlow function to produce an element). This can
lead to deadlock if the number of inter-op threadpool threads is less
than or equal to the number of concurrent activations of that
kernel. This change fixes that deadlock by moving the blocking
computation onto a background thread.

PiperOrigin-RevId: 179067816
This commit is contained in:
Derek Murray 2017-12-14 11:08:00 -08:00 committed by TensorFlower Gardener
parent 0761849e0a
commit 3d854a744d

View File

@ -448,40 +448,60 @@ class MakeIteratorOp : public OpKernel {
}
};
class ToSingleElementOp : public OpKernel {
class ToSingleElementOp : public AsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit ToSingleElementOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool(
ctx->env(), ThreadOptions(),
strings::StrCat("to_single_element_op_thread_",
SanitizeThreadSuffix(name())),
1 /* num_threads */, false /* low_latency_hint */)) {}
void Compute(OpKernelContext* ctx) override {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
auto iterator = dataset->MakeIterator("SingleElementIterator");
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
thread_pool_->Schedule([ctx, done]() {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
auto iterator = dataset->MakeIterator("SingleElementIterator");
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
OP_REQUIRES_OK(ctx,
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
OP_REQUIRES(ctx, !end_of_sequence,
errors::InvalidArgument("Dataset was empty."));
OP_REQUIRES_OK_ASYNC(
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
done);
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
errors::InvalidArgument("Dataset was empty."), done);
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
components.clear();
OP_REQUIRES_OK(ctx,
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
OP_REQUIRES(ctx, end_of_sequence,
errors::InvalidArgument("Dataset had more than one element."));
components.clear();
OP_REQUIRES_OK_ASYNC(
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
done);
OP_REQUIRES_ASYNC(
ctx, end_of_sequence,
errors::InvalidArgument("Dataset had more than one element."), done);
done();
});
}
private:
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
class OneShotIteratorOp : public AsyncOpKernel {