[tf.data] Reimplement tf.contrib.data.get_single_element()
as an async op.
The current `ToSingleElementOp` kernel has a synchronous implementation, and yet it can block an inter-op threadpool thread (in `iterator->GetNext()`) while depending on another (e.g. if the iterator calls a TensorFlow function to produce an element). This can lead to deadlock if the number of inter-op threadpool threads is less than or equal to the number of concurrent activations of that kernel. This change fixes that deadlock by moving the blocking computation onto a background thread. PiperOrigin-RevId: 179067816
This commit is contained in:
parent
0761849e0a
commit
3d854a744d
@ -448,40 +448,60 @@ class MakeIteratorOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
class ToSingleElementOp : public OpKernel {
|
||||
class ToSingleElementOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit ToSingleElementOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
explicit ToSingleElementOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
thread_pool_(new thread::ThreadPool(
|
||||
ctx->env(), ThreadOptions(),
|
||||
strings::StrCat("to_single_element_op_thread_",
|
||||
SanitizeThreadSuffix(name())),
|
||||
1 /* num_threads */, false /* low_latency_hint */)) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
auto iterator = dataset->MakeIterator("SingleElementIterator");
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
thread_pool_->Schedule([ctx, done]() {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
|
||||
auto iterator = dataset->MakeIterator("SingleElementIterator");
|
||||
|
||||
IteratorContext::Params params;
|
||||
params.env = ctx->env();
|
||||
params.runner = *(ctx->runner());
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
IteratorContext::Params params;
|
||||
params.env = ctx->env();
|
||||
params.runner = *(ctx->runner());
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence;
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence;
|
||||
|
||||
OP_REQUIRES_OK(ctx,
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
OP_REQUIRES(ctx, !end_of_sequence,
|
||||
errors::InvalidArgument("Dataset was empty."));
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
|
||||
errors::InvalidArgument("Dataset was empty."), done);
|
||||
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
|
||||
components.clear();
|
||||
OP_REQUIRES_OK(ctx,
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
OP_REQUIRES(ctx, end_of_sequence,
|
||||
errors::InvalidArgument("Dataset had more than one element."));
|
||||
components.clear();
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, end_of_sequence,
|
||||
errors::InvalidArgument("Dataset had more than one element."), done);
|
||||
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
};
|
||||
|
||||
class OneShotIteratorOp : public AsyncOpKernel {
|
||||
|
Loading…
Reference in New Issue
Block a user