From 9158b1b83a0128fc41bfccd80fe26d8231fe958b Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 21 Aug 2018 11:48:37 -0700 Subject: [PATCH] [tf.data] Move captured function instantiation to iterator initialization time. Previously, a function instantiation error (e.g. in `Dataset.map()`) would lead to an error in each GetNext() call that attempted to use the function. Moving this to iterator instantiation time has the benefit that the error will be reported once when the initialization op is executed, which has a more helpful stack trace, since it should not be conflated with other potential op failures. PiperOrigin-RevId: 209633511 --- tensorflow/core/common_runtime/function.cc | 6 ++ tensorflow/core/framework/function.h | 5 ++ .../core/kernels/data/captured_function.cc | 65 ++++++++++--------- .../core/kernels/data/captured_function.h | 4 +- .../core/kernels/data/filter_dataset_op.cc | 4 +- .../core/kernels/data/flat_map_dataset_op.cc | 4 +- .../core/kernels/data/generator_dataset_op.cc | 20 +++--- .../core/kernels/data/generator_dataset_op.h | 1 - .../data/group_by_reducer_dataset_op.cc | 9 ++- .../data/group_by_window_dataset_op.cc | 8 ++- .../kernels/data/interleave_dataset_op.cc | 5 +- tensorflow/core/kernels/data/iterator_ops.cc | 19 ++++-- .../kernels/data/map_and_batch_dataset_op.cc | 4 +- .../core/kernels/data/map_dataset_op.cc | 4 +- .../core/kernels/data/optimize_dataset_op.cc | 11 +++- .../data/parallel_interleave_dataset_op.cc | 4 +- .../kernels/data/parallel_map_dataset_op.cc | 6 +- .../kernels/data/parallel_map_iterator.cc | 28 ++++++-- .../core/kernels/data/parallel_map_iterator.h | 10 ++- .../core/kernels/data/repeat_dataset_op.cc | 38 +++++++---- .../core/kernels/data/scan_dataset_op.cc | 4 +- .../data/kernel_tests/map_dataset_op_test.py | 32 +++++++++ 22 files changed, 210 insertions(+), 81 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 54bbe84b57b..fb89bcc0df3 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate( next_handle_++; } } + + if (options.create_kernels_eagerly) { + Item* item; + TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item)); + } + return Status::OK(); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index edb7ed01e91..a2e69a152a7 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -490,6 +490,11 @@ class FunctionLibraryRuntime { // Instantiates the function using an executor of the given type. If empty, // the default TensorFlow executor will be used. string executor_type; + + // If true, the runtime will attempt to create kernels for the function at + // instantiation time, rather than on the first run. This can be used to + // surface errors earlier. + bool create_kernels_eagerly = false; }; typedef uint64 Handle; virtual Status Instantiate(const string& function_name, AttrSlice attrs, diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 82da3854056..abdf6ee4e83 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -172,31 +172,17 @@ class BorrowedArgsCallFrame : public CallFrameBase { } // namespace -Status CapturedFunction::MaybeInstantiate( - IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) { - mutex_lock l(mu_); +Status CapturedFunction::GetHandle(IteratorContext* ctx, + FunctionLibraryRuntime::Handle* out_handle) { + tf_shared_lock l(mu_); if (lib_ == nullptr) { - // The context's runtime will be used for all subsequent calls. - lib_ = ctx->lib(); - DCHECK(f_handle_ == kInvalidHandle); - FunctionLibraryRuntime::InstantiateOptions inst_opts; - inst_opts.overlay_lib = ctx->function_library().get(); - inst_opts.state_handle = std::to_string(random::New64()); - TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), - inst_opts, &f_handle_)); - const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_); - if (fbody == nullptr) { - return errors::Internal("Failed to instantiate function body."); - } - ret_types_ = fbody->ret_types; - } else { - // TODO(mrry): Consider moving this under a shared lock, as it is - // the common case. - if (ctx->lib() != lib_) { - return errors::Internal( - "Captured function was called with a different " - "FunctionLibraryRuntime*, which is not permitted."); - } + return errors::Internal("Captured function \"", func_.name(), + "\" was called before it was instantiated."); + } + if (ctx->lib() != lib_) { + return errors::Internal("Captured function \"", func_.name(), + "\" was called with a different " + "FunctionLibraryRuntime*, which is not permitted."); } *out_handle = f_handle_; return Status::OK(); @@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate( Status CapturedFunction::Run(IteratorContext* ctx, std::vector&& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); + TF_RETURN_IF_ERROR(GetHandle(ctx, &handle)); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); @@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx, const std::vector& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); + TF_RETURN_IF_ERROR(GetHandle(ctx, &handle)); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); @@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx, } Status CapturedFunction::Instantiate(IteratorContext* ctx) { - FunctionLibraryRuntime::Handle unused_handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle)); mutex_lock l(mu_); + if (lib_ == nullptr) { + // The context's runtime will be used for all subsequent calls. + lib_ = ctx->lib(); + DCHECK(f_handle_ == kInvalidHandle); + FunctionLibraryRuntime::InstantiateOptions inst_opts; + inst_opts.overlay_lib = ctx->function_library().get(); + inst_opts.state_handle = std::to_string(random::New64()); + inst_opts.create_kernels_eagerly = true; + Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), + inst_opts, &f_handle_)); + TF_RETURN_IF_ERROR(s); + const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_); + if (fbody == nullptr) { + return errors::Internal("Failed to instantiate function body."); + } + ret_types_ = fbody->ret_types; + } else { + if (ctx->lib() != lib_) { + return errors::Internal( + "Captured function was called with a different " + "FunctionLibraryRuntime*, which is not permitted."); + } + } if (captured_runner_ == nullptr) { captured_runner_ = *ctx->runner(); } @@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // be deleted before `done` is called. Take care not to capture `ctx` in any // code that may execute asynchronously in this function. FunctionLibraryRuntime::Handle handle; - Status s = MaybeInstantiate(ctx, &handle); + Status s = GetHandle(ctx, &handle); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index e9ad3e381d4..c95f2b1c017 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -116,8 +116,8 @@ class CapturedFunction { CapturedFunction(const NameAttrList& func, std::vector captured_inputs); - Status MaybeInstantiate(IteratorContext* ctx, - FunctionLibraryRuntime::Handle* out_handle); + Status GetHandle(IteratorContext* ctx, + FunctionLibraryRuntime::Handle* out_handle); mutex mu_; const NameAttrList func_; diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index a80e102ccfa..f5c7d336a66 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 07bcb9d4145..21e627a8e81 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 3c3d78b724e..ccee690d7e6 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { } } + Status Initialize(IteratorContext* ctx) override { + TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + return Status::OK(); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (!initialized_) { - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); - // Explicitly instantiate the finalize function here so that - // we can invoke it in the destructor. - TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - initialized_ = true; - } - if (finalized_) { *end_of_sequence = true; return Status::OK(); @@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { private: mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; bool finalized_ GUARDED_BY(mu_) = false; std::vector state_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index 3f84fa9c2ec..84075431365 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/kernels/data/captured_function.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index be4132a064b..4a388645f22 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->captured_finalize_func_->Instantiate(ctx)); + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 288695f3cdc..f993a689341 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->captured_window_size_func_->Instantiate(ctx)); + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 58b79d60266..6bba6677595 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -1,4 +1,3 @@ - /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { args_list_(params.dataset->cycle_length_) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) { diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 61a6c06135e..25beb02f0e4 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase { bool* end_of_sequence) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - if (lib_ != nullptr) { - ctx->set_lib(lib_); - } + CHECK_NOTNULL(lib_); + ctx->set_lib(lib_); return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence); } else { return errors::FailedPrecondition( @@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); std::unique_ptr iterator; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(lib); TF_RETURN_IF_ERROR( - dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); + dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); TF_RETURN_IF_ERROR(set_iterator(std::move(iterator))); std::shared_ptr captured_iterator(iterator_); @@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase { return lib_def_; } + FunctionLibraryRuntime* function_library_runtime() { return lib_; } + // Transfers ownership of iterator to this. This method is thread-safe. Status set_iterator(std::unique_ptr iterator) { if (iterator) { @@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { core::ScopedUnref unref(iterator_resource); std::unique_ptr iterator; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(iterator_resource->function_library_runtime()); OP_REQUIRES_OK( - ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); + ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } @@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel { DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); std::unique_ptr iter; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(lib); TF_RETURN_IF_ERROR( - dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter)); + dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter)); TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter))); (*iterator)->Ref(); diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 0e17011b051..c4df7f27567 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 294fb1c49a1..26ae26a7fdf 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index b097598cd94..b2d307ba8a1 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->optimized_input_->MakeIterator(ctx, prefix(), - &input_impl_); + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.lib = ctx->lib(); + params.function_library = dataset()->flib_def_; + params.allocator_getter = ctx->allocator_getter(); + return dataset()->optimized_input_->MakeIterator( + IteratorContext(params), prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index cfa96d910d3..bf86361a718 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } // It is implemented so that it matches the deterministic interleave diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index a407abfce45..e03a4e353bf 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { + auto init_func = [this](IteratorContext* ctx) { + return captured_func_->Instantiate(ctx); + }; + auto map_func = [this](IteratorContext* ctx, std::vector input_element, std::vector* result, StatusCallback done) { @@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return NewParallelMapIterator( {this, strings::StrCat(prefix, "::ParallelMap")}, input_, - std::move(map_func), num_parallel_calls_); + std::move(init_func), std::move(map_func), num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 4d32b719a42..61f8139b9e7 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator { public: explicit ParallelMapIterator( const typename DatasetBaseIterator::BaseParams& params, - const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, - int32 num_parallel_calls) + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls) : DatasetBaseIterator(params), input_dataset_(input_dataset), + init_func_(std::move(init_func)), map_func_(std::move(map_func)), num_parallel_calls_(num_parallel_calls) {} @@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status Initialize(IteratorContext* ctx) override { - return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); + if (init_func_) { + TF_RETURN_IF_ERROR(init_func_(ctx)); + } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator { } const DatasetBase* const input_dataset_; // Not owned. + const std::function init_func_; const ParallelMapIteratorFunction map_func_; const int32 num_parallel_calls_; // Used for coordination between the main thread and the runner thread. @@ -311,8 +319,18 @@ std::unique_ptr NewParallelMapIterator( const DatasetBaseIterator::BaseParams& params, const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr(new ParallelMapIterator( - params, input_dataset, std::move(map_func), num_parallel_calls)); + return NewParallelMapIterator(params, input_dataset, nullptr, + std::move(map_func), num_parallel_calls); +} + +std::unique_ptr NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { + return std::unique_ptr( + new ParallelMapIterator(params, input_dataset, std::move(init_func), + std::move(map_func), num_parallel_calls)); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 2ce36c38690..7e6cc586f30 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -33,7 +33,15 @@ using ParallelMapIteratorFunction = std::vector*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of -// `input_dataset` using the given degree of parallelism. +// `input_dataset` using the given degree of parallelism. `init_func` (if +// specified) will be executed when the iterator is initialized (see +// `IteratorBase::Initialize()`) and enables the user to specify error checking +// logic that can fail early. +std::unique_ptr NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls); std::unique_ptr NewParallelMapIterator( const DatasetBaseIterator::BaseParams& params, const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 5e9ace3486e..299949b99f9 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { class ForeverIterator : public DatasetIterator { public: explicit ForeverIterator(const Params& params) - : DatasetIterator(params), input_impl_(nullptr) {} + : DatasetIterator(params), + input_impl_(nullptr), + first_call_(true) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. do { - bool first_call = false; if (!input_impl_) { - first_call = true; TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (!*end_of_sequence) { + Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + if (first_call_ && *end_of_sequence) { + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) + input_impl_.reset(); return Status::OK(); + } + first_call_ = false; + if (!*end_of_sequence) { + return s; } else { input_impl_.reset(); - if (first_call) { - // If the first call to GetNext() fails because the end - // of sequence has been reached, we terminate the - // iteration immediately. (Otherwise, this iterator - // would loop infinitely and never produce a value.) - return Status::OK(); - } + first_call_ = true; } } while (true); } @@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - if (input_impl_) + if (!first_call_) TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); else TF_RETURN_IF_ERROR( @@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (reader->Contains(full_name("uninitialized"))) { input_impl_.reset(); + first_call_ = true; } else { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + first_call_ = false; } return Status::OK(); } @@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { private: mutex mu_; std::unique_ptr input_impl_ GUARDED_BY(mu_); + bool first_call_ GUARDED_BY(mu_); }; const int64 count_; diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index e4cb31e2b2e..5d3319b19fa 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { state_(params.dataset->initial_state_) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 637bde9ae4e..52b4320bf1b 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -24,6 +24,7 @@ import warnings import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase): r"Dataset.map\(\): None."): _ = dataset.map(lambda x: None) + def testBrokenFunctionErrorOnInitialization(self): + dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0]) + + def broken_function(_): + """A function deliberately designed to fail on instantiation.""" + value = [] + tensor_value = attr_value_pb2.AttrValue() + tensor_value.tensor.CopyFrom( + tensor_util.make_tensor_proto( + value, dtype=dtypes.float32, shape=[0], verify_shape=False)) + dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum) + + # Create a "Const" op with a `tf.float32` value and a `tf.int32` type + # attr. + const_tensor = ops.get_default_graph().create_op( + "Const", [], [dtypes.int32], + attrs={ + "value": tensor_value, + "dtype": dtype_value + }, + name="BrokenConst").outputs[0] + return const_tensor + + dataset = dataset.map(broken_function) + iterator = dataset.make_initializable_iterator() + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"): + sess.run(iterator.initializer) + class MapDatasetBenchmark(test.Benchmark):