From 4855dd694fa3f72b04d6798e1d7421fc2446da07 Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Tue, 26 Nov 2019 09:44:49 -0800
Subject: [PATCH] [tf.data] Simplify error handling in async kernels that use
 a BackgroundWorker.

This CL moves the main logic of three async tf.data kernels to a
separate DoCompute() method that returns a Status. This enables us to
use RAII for cleaning up objects in error cases, without having to
manipulate the `done` callback.

PiperOrigin-RevId: 282584164
Change-Id: I7c7a27c8f404826146adcf29c8b4a62912155be5
---
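(Reviewer note, below the fold so it is not committed: all three kernels are
restructured the same way. The minimal sketch below shows the shape of the
pattern using a hypothetical MyAsyncKernel; it is an illustration under
assumed tf.data includes, not code from this patch.)

    class MyAsyncKernel : public AsyncOpKernel {
     public:
      void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
        // Run on the background worker so a blocking GetNext() cannot tie
        // up an inter-op thread pool thread.
        background_worker_.Schedule([this, ctx, done = std::move(done)]() {
          // On error, OP_REQUIRES_OK_ASYNC sets the status on `ctx` and
          // invokes `done`; every local of DoCompute() has already been
          // destroyed by then, so no cleanup is left to the callback.
          OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
          done();
        });
      }

     private:
      Status DoCompute(OpKernelContext* ctx) {
        // All fallible work goes here, so TF_RETURN_IF_ERROR and RAII
        // cleanups (e.g. gtl::MakeCleanup) cover every early-exit path,
        // instead of rebinding `done` once per resource as the old code did.
        // ... op-specific work ...
        return Status::OK();
      }

      BackgroundWorker background_worker_;
    };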
 .../data/experimental/to_tf_record_op.cc      | 139 +++---
 tensorflow/core/kernels/data/iterator_ops.cc  | 387 +++++++-----
 2 files changed, 210 insertions(+), 316 deletions(-)

diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
index 9103880fc41..d01ecbd2930 100644
--- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
+++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc
@@ -48,93 +48,64 @@ class ToTFRecordOp : public AsyncOpKernel {
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     // The call to `iterator->GetNext()` may block and depend on an inter-op
     // thread pool thread, so we issue the call using a background thread.
-    background_worker_.Schedule(std::bind(
-        [this, ctx](std::function<void()>& done) {
-          tstring filename;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, ParseScalarArgument<tstring>(ctx, "filename", &filename),
-              done);
-          tstring compression_type;
-          OP_REQUIRES_OK_ASYNC(ctx,
-                               ParseScalarArgument<tstring>(
-                                   ctx, "compression_type", &compression_type),
-                               done);
-          std::unique_ptr<WritableFile> file;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, ctx->env()->NewWritableFile(filename, &file), done);
-          auto writer = absl::make_unique<io::RecordWriter>(
-              file.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
-                              compression_type));
-
-          DatasetBase* dataset;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
-
-          IteratorContext::Params params(ctx);
-          FunctionHandleCache function_handle_cache(params.flr);
-          params.function_handle_cache = &function_handle_cache;
-          ResourceMgr resource_mgr;
-          params.resource_mgr = &resource_mgr;
-          CancellationManager cancellation_manager;
-          params.cancellation_manager = &cancellation_manager;
-          std::function<void()> deregister_fn;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              RegisterCancellationCallback(
-                  ctx->cancellation_manager(),
-                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
-                  &deregister_fn),
-              done);
-
-          // Update the `done` callback to deregister the cancellation
-          // callback.
-          done = std::bind(
-              [](const std::function<void()>& done,
-                 const std::function<void()>& deregister_fn) {
-                deregister_fn();
-                done();
-              },
-              std::move(done), std::move(deregister_fn));
-
-          IteratorContext iter_ctx(std::move(params));
-          std::unique_ptr<IteratorBase> iterator;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator",
-                                    &iterator),
-              done);
-
-          // Update the `done` callback to destroy the iterator before calling
-          // the actual callback to avoid destruction races.
-          IteratorBase* raw_iterator = iterator.release();
-          done = std::bind(
-              [raw_iterator](const std::function<void()>& done) {
-                delete raw_iterator;
-                done();
-              },
-              std::move(done));
-
-          std::vector<Tensor> components;
-          components.reserve(dataset->output_dtypes().size());
-          bool end_of_sequence;
-          do {
-            OP_REQUIRES_OK_ASYNC(
-                ctx,
-                raw_iterator->GetNext(&iter_ctx, &components,
-                                      &end_of_sequence),
-                done);
-
-            if (!end_of_sequence) {
-              OP_REQUIRES_OK_ASYNC(
-                  ctx, writer->WriteRecord(components[0].scalar<tstring>()()),
-                  done);
-            }
-            components.clear();
-          } while (!end_of_sequence);
-          done();
-        },
-        std::move(done)));
+    background_worker_.Schedule([this, ctx, done = std::move(done)]() {
+      OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
+      done();
+    });
   }
 
  private:
+  Status DoCompute(OpKernelContext* ctx) {
+    tstring filename;
+    TF_RETURN_IF_ERROR(
+        ParseScalarArgument<tstring>(ctx, "filename", &filename));
+    tstring compression_type;
+    TF_RETURN_IF_ERROR(ParseScalarArgument<tstring>(ctx, "compression_type",
+                                                    &compression_type));
+    std::unique_ptr<WritableFile> file;
+    TF_RETURN_IF_ERROR(ctx->env()->NewWritableFile(filename, &file));
+    auto writer = absl::make_unique<io::RecordWriter>(
+        file.get(),
+        io::RecordWriterOptions::CreateRecordWriterOptions(compression_type));
+
+    DatasetBase* dataset;
+    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+    IteratorContext::Params params(ctx);
+    FunctionHandleCache function_handle_cache(params.flr);
+    params.function_handle_cache = &function_handle_cache;
+    ResourceMgr resource_mgr;
+    params.resource_mgr = &resource_mgr;
+    CancellationManager cancellation_manager;
+    params.cancellation_manager = &cancellation_manager;
+    std::function<void()> deregister_fn;
+    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
+        ctx->cancellation_manager(),
+        [cm = params.cancellation_manager]() { cm->StartCancel(); },
+        &deregister_fn));
+    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
+
+    IteratorContext iter_ctx(std::move(params));
+    std::unique_ptr<IteratorBase> iterator;
+    TF_RETURN_IF_ERROR(
+        dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
+
+    std::vector<Tensor> components;
+    components.reserve(dataset->output_dtypes().size());
+    bool end_of_sequence;
+    do {
+      TF_RETURN_IF_ERROR(
+          iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+
+      if (!end_of_sequence) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteRecord(components[0].scalar<tstring>()()));
+      }
+      components.clear();
+    } while (!end_of_sequence);
+    return Status::OK();
+  }
+
   BackgroundWorker background_worker_;
 };
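(Reviewer note: the gtl::MakeCleanup(std::move(deregister_fn)) call above is
what replaces the old rebind-of-`done` dance. The returned object runs the
stored callable from its destructor, so the cancellation callback is
deregistered on every path out of DoCompute(). A self-contained sketch of
those semantics, with a hand-rolled stand-in for the helper in
tensorflow/core/lib/gtl/cleanup.h rather than the real implementation:)

    #include <functional>
    #include <iostream>
    #include <utility>

    // Stand-in for gtl::Cleanup: runs the stored callable when it goes out
    // of scope, i.e. on *every* return path, including early error returns.
    class Cleanup {
     public:
      explicit Cleanup(std::function<void()> f) : f_(std::move(f)) {}
      ~Cleanup() {
        if (f_) f_();
      }
      Cleanup(const Cleanup&) = delete;
      Cleanup& operator=(const Cleanup&) = delete;

     private:
      std::function<void()> f_;
    };

    bool DoWork(bool fail_early) {
      Cleanup cleanup([] { std::cout << "deregistered\n"; });
      if (fail_early) return false;  // cleanup still runs here...
      std::cout << "work done\n";
      return true;  // ...and here.
    }

    int main() {
      DoWork(true);   // prints: deregistered
      DoWork(false);  // prints: work done, then deregistered
    }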
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index adc5c56d61c..49046cfd188 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -453,96 +453,60 @@ class ToSingleElementOp : public AsyncOpKernel {
     // The call to `iterator->GetNext()` may block and depend on an
     // inter-op thread pool thread, so we issue the call from the
     // owned thread pool.
-    background_worker_.Schedule(std::bind(
-        [ctx](std::function<void()>& done) {
-          DatasetBase* dataset;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
-
-          IteratorContext::Params params(ctx);
-          FunctionHandleCache function_handle_cache(params.flr);
-          params.function_handle_cache = &function_handle_cache;
-          ResourceMgr resource_mgr;
-          params.resource_mgr = &resource_mgr;
-          CancellationManager cancellation_manager;
-          params.cancellation_manager = &cancellation_manager;
-          std::function<void()> deregister_fn;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              RegisterCancellationCallback(
-                  ctx->cancellation_manager(),
-                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
-                  &deregister_fn),
-              done);
-
-          // Update the `done` callback to deregister the cancellation
-          // callback.
-          done = std::bind(
-              [](const std::function<void()>& done,
-                 const std::function<void()>& deregister_fn) {
-                deregister_fn();
-                done();
-              },
-              std::move(done), std::move(deregister_fn));
-
-          IteratorContext iter_ctx(std::move(params));
-          std::unique_ptr<IteratorBase> iterator;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              dataset->MakeIterator(&iter_ctx, "SingleElementIterator",
-                                    &iterator),
-              done);
-
-          // Update the `done` callback to destroy the iterator before calling
-          // the actual callback to avoid destruction races.
-          IteratorBase* raw_iterator = iterator.release();
-          done = std::bind(
-              [raw_iterator](const std::function<void()>& done) {
-                delete raw_iterator;
-                done();
-              },
-              std::move(done));
-
-          std::vector<Tensor> components;
-          components.reserve(dataset->output_dtypes().size());
-          bool end_of_sequence = false;
-
-          Status s =
-              raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
-          if (!s.ok()) {
-            ctx->SetStatus(s);
-            done();
-            return;
-          }
-          if (end_of_sequence) {
-            ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
-            done();
-            return;
-          }
-          for (int i = 0; i < components.size(); ++i) {
-            // TODO(mrry): Check that the shapes match the shape attrs.
-            ctx->set_output(i, components[i]);
-          }
-
-          components.clear();
-          s.Update(
-              raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
-          if (!s.ok()) {
-            ctx->SetStatus(s);
-            done();
-            return;
-          }
-          if (!end_of_sequence) {
-            ctx->SetStatus(
-                errors::InvalidArgument("Dataset had more than one element."));
-            done();
-            return;
-          }
-          done();
-        },
-        std::move(done)));
+    background_worker_.Schedule([this, ctx, done = std::move(done)]() {
+      OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
+      done();
+    });
   }
 
  private:
+  Status DoCompute(OpKernelContext* ctx) {
+    DatasetBase* dataset;
+    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+    IteratorContext::Params params(ctx);
+    FunctionHandleCache function_handle_cache(params.flr);
+    params.function_handle_cache = &function_handle_cache;
+    ResourceMgr resource_mgr;
+    params.resource_mgr = &resource_mgr;
+    CancellationManager cancellation_manager;
+    params.cancellation_manager = &cancellation_manager;
+    std::function<void()> deregister_fn;
+    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
+        ctx->cancellation_manager(),
+        [cm = params.cancellation_manager]() { cm->StartCancel(); },
+        &deregister_fn));
+    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
+
+    IteratorContext iter_ctx(std::move(params));
+    std::unique_ptr<IteratorBase> iterator;
+    TF_RETURN_IF_ERROR(
+        dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
+
+    std::vector<Tensor> components;
+    components.reserve(dataset->output_dtypes().size());
+    bool end_of_sequence = false;
+
+    TF_RETURN_IF_ERROR(
+        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+
+    if (end_of_sequence) {
+      return errors::InvalidArgument("Dataset was empty.");
+    }
+    for (int i = 0; i < components.size(); ++i) {
+      // TODO(mrry): Check that the shapes match the shape attrs.
+      ctx->set_output(i, components[i]);
+    }
+
+    components.clear();
+    TF_RETURN_IF_ERROR(
+        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+    if (!end_of_sequence) {
+      return errors::InvalidArgument("Dataset had more than one element.");
+    }
+    return Status::OK();
+  }
+
   BackgroundWorker background_worker_;
 };
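(Reviewer note: the deleted "destroy the iterator before calling the actual
callback" rebinding is now handled by ordinary scoping: `done()` runs only
after DoCompute() has returned, by which point its locals, including the
std::unique_ptr<IteratorBase>, have been destroyed. A standard-C++-only
sketch of that ordering, with a hypothetical Iterator type rather than the
TF classes:)

    #include <iostream>
    #include <memory>

    // Hypothetical stand-in for IteratorBase.
    struct Iterator {
      ~Iterator() { std::cout << "iterator destroyed\n"; }
    };

    int DoCompute() {
      auto iterator = std::make_unique<Iterator>();
      return 0;  // `iterator` is destroyed here, before the caller resumes.
    }

    int main() {
      int status = DoCompute();
      // Anything run here (like the `done` callback) sees the iterator gone.
      std::cout << "done() would run now, status=" << status << "\n";
    }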
@@ -565,154 +529,113 @@ class ReduceDatasetOp : public AsyncOpKernel {
     // The call to `iterator->GetNext()` may block and depend on an
     // inter-op thread pool thread, so we issue the call from the
     // owned thread pool.
-    background_worker_.Schedule(std::bind(
-        [this, ctx](std::function<void()>& done) {
-          DatasetBase* dataset;
-          OP_REQUIRES_OK_ASYNC(
-              ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
-          OpInputList inputs;
-          OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
-                               done);
-          std::vector<Tensor> state(inputs.begin(), inputs.end());
-
-          std::unique_ptr<CapturedFunction> captured_func;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              CapturedFunction::Create(ctx, func_metadata_, "other_arguments",
-                                       &captured_func),
-              done);
-
-          IteratorContext::Params params(ctx);
-          auto function_handle_cache =
-              absl::make_unique<FunctionHandleCache>(params.flr);
-          params.function_handle_cache = function_handle_cache.get();
-          ResourceMgr resource_mgr;
-          params.resource_mgr = &resource_mgr;
-          CancellationManager cancellation_manager;
-          params.cancellation_manager = &cancellation_manager;
-          std::function<void()> deregister_fn;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              RegisterCancellationCallback(
-                  ctx->cancellation_manager(),
-                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
-                  &deregister_fn),
-              done);
-
-          // Update the `done` callback to deregister the cancellation
-          // callback.
-          done = std::bind(
-              [](const std::function<void()>& done,
-                 const std::function<void()>& deregister_fn) {
-                deregister_fn();
-                done();
-              },
-              std::move(done), std::move(deregister_fn));
-
-          IteratorContext iter_ctx(std::move(params));
-          std::unique_ptr<InstantiatedCapturedFunction>
-              instantiated_captured_func;
-          OP_REQUIRES_OK_ASYNC(ctx,
-                               captured_func->Instantiate(
-                                   &iter_ctx, &instantiated_captured_func),
-                               done);
-
-          std::unique_ptr<IteratorBase> iterator;
-          OP_REQUIRES_OK_ASYNC(
-              ctx,
-              dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
-              done);
-
-          // Update the `done` callback to destroy the iterator before calling
-          // the actual callback to avoid destruction races.
-          IteratorBase* raw_iterator = iterator.release();
-          done = std::bind(
-              [raw_iterator](const std::function<void()>& done) {
-                delete raw_iterator;
-                done();
-              },
-              std::move(done));
-
-          // Iterate through the input dataset.
-          Status status;
-          while (true) {
-            OP_REQUIRES_ASYNC(ctx, !ctx->cancellation_manager()->IsCancelled(),
-                              errors::Cancelled("Operation was cancelled"),
-                              done);
-            std::vector<Tensor> next_input_element;
-            bool end_of_input;
-            status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
-                                           &end_of_input);
-            if (!status.ok() || end_of_input) {
-              break;
-            }
-
-            // Run the reduce function to update the current state.
-            std::vector<Tensor> args;
-            args.reserve(state.size() + next_input_element.size());
-            std::copy(state.begin(), state.end(), std::back_inserter(args));
-            std::copy(next_input_element.begin(), next_input_element.end(),
-                      std::back_inserter(args));
-
-            std::vector<Tensor> reduce_func_output;
-            status = instantiated_captured_func->Run(
-                &iter_ctx, std::move(args), &reduce_func_output);
-            if (!status.ok()) {
-              break;
-            }
-            OP_REQUIRES_ASYNC(
-                ctx, reduce_func_output.size() == state.size(),
-                errors::InvalidArgument(
-                    "The number of components of the initial state and the "
-                    "reduce function output does not match. (initial_state=",
-                    state.size(), ", output=", reduce_func_output.size(),
-                    ")."),
-                done);
-            std::swap(reduce_func_output, state);
-          }
-
-          if (!status.ok()) {
-            ctx->SetStatus(status);
-            done();
-            return;
-          }
-
-          OP_REQUIRES_ASYNC(ctx, state.size() == output_types_.size(),
-                            errors::InvalidArgument(
-                                "The number of result elements does not match "
-                                "the size of output types: ",
-                                state.size(), " vs. ", output_types_.size()),
-                            done);
-          OP_REQUIRES_ASYNC(ctx, state.size() == output_shapes_.size(),
-                            errors::InvalidArgument(
-                                "The number of result elements does not match "
-                                "the size of output shapes: ",
-                                state.size(), " vs. ", output_shapes_.size()),
-                            done);
-          for (int i = 0; i < state.size(); ++i) {
-            OP_REQUIRES_ASYNC(
-                ctx, state[i].dtype() == output_types_[i],
-                errors::InvalidArgument(
-                    "The result does not match the expected type for "
-                    "component ",
-                    i, ". Expected: ", DataTypeString(output_types_[i]),
-                    ". Actual: ", DataTypeString(state[i].dtype()), "."),
-                done);
-            OP_REQUIRES_ASYNC(
-                ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
-                errors::InvalidArgument(
-                    "The result does not match the expected shape for "
-                    "component ",
-                    i, ". Expected: ", output_shapes_[i].DebugString(),
Actual: ", state[i].shape().DebugString(), "."), - done); - ctx->set_output(i, state[i]); - } - done(); - }, - std::move(done))); + background_worker_.Schedule([this, ctx, done = std::move(done)]() { + OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done); + done(); + }); } private: + Status DoCompute(OpKernelContext* ctx) { + DatasetBase* dataset; + TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + OpInputList inputs; + TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs)); + std::vector state(inputs.begin(), inputs.end()); + + std::unique_ptr captured_func; + TF_RETURN_IF_ERROR(CapturedFunction::Create( + ctx, func_metadata_, "other_arguments", &captured_func)); + + IteratorContext::Params params(ctx); + auto function_handle_cache = + absl::make_unique(params.flr); + params.function_handle_cache = function_handle_cache.get(); + ResourceMgr resource_mgr; + params.resource_mgr = &resource_mgr; + CancellationManager cancellation_manager; + params.cancellation_manager = &cancellation_manager; + std::function deregister_fn; + TF_RETURN_IF_ERROR(RegisterCancellationCallback( + ctx->cancellation_manager(), + [cm = params.cancellation_manager]() { cm->StartCancel(); }, + &deregister_fn)); + auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); + + IteratorContext iter_ctx(std::move(params)); + std::unique_ptr instantiated_captured_func; + TF_RETURN_IF_ERROR( + captured_func->Instantiate(&iter_ctx, &instantiated_captured_func)); + + std::unique_ptr iterator; + TF_RETURN_IF_ERROR( + dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator)); + + // Iterate through the input dataset. + while (true) { + if (ctx->cancellation_manager()->IsCancelled()) { + return errors::Cancelled("Operation was cancelled"); + } + std::vector next_input_element; + bool end_of_input; + TF_RETURN_IF_ERROR( + iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input)); + if (end_of_input) { + break; + } + + // Run the reduce function to update the current state. + std::vector args; + args.reserve(state.size() + next_input_element.size()); + std::copy(state.begin(), state.end(), std::back_inserter(args)); + std::copy(next_input_element.begin(), next_input_element.end(), + std::back_inserter(args)); + + std::vector reduce_func_output; + TF_RETURN_IF_ERROR(instantiated_captured_func->Run( + &iter_ctx, std::move(args), &reduce_func_output)); + if (reduce_func_output.size() != state.size()) { + return errors::InvalidArgument( + "The number of components of the initial state and the " + "reduce " + "function output does not match. (initial_state=", + state.size(), ", output=", reduce_func_output.size(), ")."); + } + std::swap(reduce_func_output, state); + } + + if (state.size() != output_types_.size()) { + return errors::InvalidArgument( + "The number of result elements does not match " + "the size of output types: ", + state.size(), " vs. ", output_types_.size()); + } + if (state.size() != output_shapes_.size()) { + return errors::InvalidArgument( + "The number of result elements does not match " + "the size of output shapes: ", + state.size(), " vs. ", output_shapes_.size()); + } + for (size_t i = 0; i < state.size(); ++i) { + if (state[i].dtype() != output_types_[i]) { + return errors::InvalidArgument( + "The result does not match the expected type for " + "component ", + i, ". Expected: ", DataTypeString(output_types_[i]), + ". 
Actual: ", DataTypeString(state[i].dtype()), "."); + } + if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) { + return errors::InvalidArgument( + "The result does not match the expected shape for " + "component ", + i, ". Expected: ", output_shapes_[i].DebugString(), + ". Actual: ", state[i].shape().DebugString(), "."); + } + ctx->set_output(i, state[i]); + } + return Status::OK(); + } + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_;