[tf.data] Simplify error handling in async kernels that use a BackgroundWorker.

This CL moves the main logic of three async tf.data kernels to a separate DoCompute() method that returns a Status. This enables us to use RAII for cleaning up objects in error cases, without having to manipulate the `done` callback.

PiperOrigin-RevId: 282584164
Change-Id: I7c7a27c8f404826146adcf29c8b4a62912155be5
This commit is contained in:
Derek Murray 2019-11-26 09:44:49 -08:00 committed by TensorFlower Gardener
parent 1b5e0b75ce
commit 4855dd694f
2 changed files with 210 additions and 316 deletions

View File

@ -48,93 +48,64 @@ class ToTFRecordOp : public AsyncOpKernel {
// Asynchronously writes every element of the input dataset to a TFRecord
// file. All of the real work (and its error handling) lives in DoCompute();
// OP_REQUIRES_OK_ASYNC routes any failure into the op's status before the
// `done` callback is invoked.
//
// NOTE(review): the rendered diff had left the pre-refactor inline
// implementation interleaved with this replacement, so Schedule() was
// called twice and `done` was moved-from twice; only the post-refactor
// version is kept here.
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
  // The call to `iterator->GetNext()` may block and depend on an inter-op
  // thread pool thread, so we issue the call using a background thread.
  background_worker_.Schedule([this, ctx, done = std::move(done)]() {
    OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
    done();
  });
}
private:
// Performs the actual work of the op: iterates over the input dataset and
// writes each (scalar string) element as a record to the requested file,
// honoring the requested compression type.
//
// Returns the first error encountered while parsing arguments, opening the
// file, iterating, writing, or closing the file; Status::OK() on success.
Status DoCompute(OpKernelContext* ctx) {
  tstring filename;
  TF_RETURN_IF_ERROR(
      ParseScalarArgument<tstring>(ctx, "filename", &filename));
  tstring compression_type;
  TF_RETURN_IF_ERROR(ParseScalarArgument<tstring>(ctx, "compression_type",
                                                  &compression_type));
  std::unique_ptr<WritableFile> file;
  TF_RETURN_IF_ERROR(ctx->env()->NewWritableFile(filename, &file));
  auto writer = absl::make_unique<io::RecordWriter>(
      file.get(),
      io::RecordWriterOptions::CreateRecordWriterOptions(compression_type));
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorContext::Params params(ctx);
  FunctionHandleCache function_handle_cache(params.flr);
  params.function_handle_cache = &function_handle_cache;
  ResourceMgr resource_mgr;
  params.resource_mgr = &resource_mgr;
  CancellationManager cancellation_manager;
  params.cancellation_manager = &cancellation_manager;
  // Forward a cancellation of the op to the iterator's cancellation
  // manager; `cleanup` deregisters the callback (RAII) on every exit path.
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  IteratorContext iter_ctx(std::move(params));
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(
      dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
  std::vector<Tensor> components;
  components.reserve(dataset->output_dtypes().size());
  bool end_of_sequence;
  do {
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    if (!end_of_sequence) {
      TF_RETURN_IF_ERROR(
          writer->WriteRecord(components[0].scalar<tstring>()()));
    }
    components.clear();
  } while (!end_of_sequence);
  // Explicitly close the writer and the underlying file so that buffered
  // writes are flushed and any flush/close error is surfaced to the caller
  // instead of being swallowed by the destructors.
  TF_RETURN_IF_ERROR(writer->Close());
  return file->Close();
}
BackgroundWorker background_worker_;
};

View File

@ -453,96 +453,60 @@ class ToSingleElementOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
//
// All of the real work (and its error handling) lives in DoCompute();
// OP_REQUIRES_OK_ASYNC routes any failure into the op's status before
// the `done` callback is invoked.
//
// NOTE(review): the rendered diff had left the pre-refactor inline
// implementation interleaved with this replacement, so Schedule() was
// called twice and `done` was moved-from twice; only the post-refactor
// version is kept here.
background_worker_.Schedule([this, ctx, done = std::move(done)]() {
  OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
  done();
});
}
private:
// Performs the actual work of the op: extracts the single element of the
// input dataset and forwards its components to the op's outputs.
//
// Returns InvalidArgument if the dataset is empty or has more than one
// element; otherwise forwards any iterator error, or Status::OK().
Status DoCompute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorContext::Params params(ctx);
  FunctionHandleCache function_handle_cache(params.flr);
  params.function_handle_cache = &function_handle_cache;
  ResourceMgr resource_mgr;
  params.resource_mgr = &resource_mgr;
  CancellationManager cancellation_manager;
  params.cancellation_manager = &cancellation_manager;
  // Forward a cancellation of the op to the iterator's cancellation
  // manager; `cleanup` deregisters the callback (RAII) on every exit path.
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  IteratorContext iter_ctx(std::move(params));
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(
      dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
  std::vector<Tensor> components;
  components.reserve(dataset->output_dtypes().size());
  bool end_of_sequence = false;
  TF_RETURN_IF_ERROR(
      iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
  if (end_of_sequence) {
    return errors::InvalidArgument("Dataset was empty.");
  }
  // size_t matches components.size() (no signed/unsigned comparison) and
  // is consistent with the equivalent loop in ReduceDatasetOp::DoCompute.
  for (size_t i = 0; i < components.size(); ++i) {
    // TODO(mrry): Check that the shapes match the shape attrs.
    ctx->set_output(i, components[i]);
  }
  components.clear();
  // A second GetNext() must immediately report end-of-sequence; otherwise
  // the dataset had more than one element.
  TF_RETURN_IF_ERROR(
      iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
  if (!end_of_sequence) {
    return errors::InvalidArgument("Dataset had more than one element.");
  }
  return Status::OK();
}
BackgroundWorker background_worker_;
};
@ -565,154 +529,113 @@ class ReduceDatasetOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
//
// DoCompute() contains the actual reduction and its error handling;
// OP_REQUIRES_OK_ASYNC routes any failure into the op's status before
// the `done` callback is invoked.
//
// NOTE(review): the rendered diff had left the pre-refactor inline
// implementation interleaved with this replacement, so Schedule() was
// called twice and `done` was moved-from twice; only the post-refactor
// version is kept here.
background_worker_.Schedule([this, ctx, done = std::move(done)]() {
  OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
  done();
});
}
private:
// Performs the actual work of the op: folds the elements of the input
// dataset into a final state using the captured reduce function, then
// validates that state against the op's declared output dtypes/shapes and
// emits it as the op's outputs.
//
// Returns Cancelled if the op is cancelled mid-reduction, InvalidArgument
// on any signature mismatch, otherwise forwards iterator/function errors,
// or Status::OK() on success.
Status DoCompute(OpKernelContext* ctx) {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
OpInputList inputs;
TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
// The running reduction state, seeded from the "initial_state" inputs.
std::vector<Tensor> state(inputs.begin(), inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
TF_RETURN_IF_ERROR(CapturedFunction::Create(
ctx, func_metadata_, "other_arguments", &captured_func));
IteratorContext::Params params(ctx);
auto function_handle_cache =
absl::make_unique<FunctionHandleCache>(params.flr);
params.function_handle_cache = function_handle_cache.get();
ResourceMgr resource_mgr;
params.resource_mgr = &resource_mgr;
CancellationManager cancellation_manager;
params.cancellation_manager = &cancellation_manager;
// Forward a cancellation of the op to the iterator's cancellation
// manager; `cleanup` deregisters the callback (RAII) on every exit path.
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
IteratorContext iter_ctx(std::move(params));
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
TF_RETURN_IF_ERROR(
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator));
// Iterate through the input dataset.
while (true) {
// Check for cancellation between elements so that a long-running
// reduction can be interrupted.
if (ctx->cancellation_manager()->IsCancelled()) {
return errors::Cancelled("Operation was cancelled");
}
std::vector<Tensor> next_input_element;
bool end_of_input;
TF_RETURN_IF_ERROR(
iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
if (end_of_input) {
break;
}
// Run the reduce function to update the current state.
// Its arguments are the current state followed by the next element.
std::vector<Tensor> args;
args.reserve(state.size() + next_input_element.size());
std::copy(state.begin(), state.end(), std::back_inserter(args));
std::copy(next_input_element.begin(), next_input_element.end(),
std::back_inserter(args));
std::vector<Tensor> reduce_func_output;
TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
&iter_ctx, std::move(args), &reduce_func_output));
if (reduce_func_output.size() != state.size()) {
return errors::InvalidArgument(
"The number of components of the initial state and the "
"reduce "
"function output does not match. (initial_state=",
state.size(), ", output=", reduce_func_output.size(), ").");
}
// The reduce function's output becomes the state for the next element.
std::swap(reduce_func_output, state);
}
// Validate the final state against the declared output signature
// (arity, dtypes, shapes) before emitting it.
if (state.size() != output_types_.size()) {
return errors::InvalidArgument(
"The number of result elements does not match "
"the size of output types: ",
state.size(), " vs. ", output_types_.size());
}
if (state.size() != output_shapes_.size()) {
return errors::InvalidArgument(
"The number of result elements does not match "
"the size of output shapes: ",
state.size(), " vs. ", output_shapes_.size());
}
for (size_t i = 0; i < state.size(); ++i) {
if (state[i].dtype() != output_types_[i]) {
return errors::InvalidArgument(
"The result does not match the expected type for "
"component ",
i, ". Expected: ", DataTypeString(output_types_[i]),
". Actual: ", DataTypeString(state[i].dtype()), ".");
}
if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
return errors::InvalidArgument(
"The result does not match the expected shape for "
"component ",
i, ". Expected: ", output_shapes_[i].DebugString(),
". Actual: ", state[i].shape().DebugString(), ".");
}
ctx->set_output(i, state[i]);
}
return Status::OK();
}
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;