diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 54bbe84b57b..fb89bcc0df3 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate( next_handle_++; } } + + if (options.create_kernels_eagerly) { + Item* item; + TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item)); + } + return Status::OK(); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index edb7ed01e91..a2e69a152a7 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -490,6 +490,11 @@ class FunctionLibraryRuntime { // Instantiates the function using an executor of the given type. If empty, // the default TensorFlow executor will be used. string executor_type; + + // If true, the runtime will attempt to create kernels for the function at + // instantiation time, rather than on the first run. This can be used to + // surface errors earlier. + bool create_kernels_eagerly = false; }; typedef uint64 Handle; virtual Status Instantiate(const string& function_name, AttrSlice attrs, diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 82da3854056..abdf6ee4e83 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -172,31 +172,17 @@ class BorrowedArgsCallFrame : public CallFrameBase { } // namespace -Status CapturedFunction::MaybeInstantiate( - IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) { - mutex_lock l(mu_); +Status CapturedFunction::GetHandle(IteratorContext* ctx, + FunctionLibraryRuntime::Handle* out_handle) { + tf_shared_lock l(mu_); if (lib_ == nullptr) { - // The context's runtime will be used for all subsequent calls. - lib_ = ctx->lib(); - DCHECK(f_handle_ == kInvalidHandle); - FunctionLibraryRuntime::InstantiateOptions inst_opts; - inst_opts.overlay_lib = ctx->function_library().get(); - inst_opts.state_handle = std::to_string(random::New64()); - TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), - inst_opts, &f_handle_)); - const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_); - if (fbody == nullptr) { - return errors::Internal("Failed to instantiate function body."); - } - ret_types_ = fbody->ret_types; - } else { - // TODO(mrry): Consider moving this under a shared lock, as it is - // the common case. - if (ctx->lib() != lib_) { - return errors::Internal( - "Captured function was called with a different " - "FunctionLibraryRuntime*, which is not permitted."); - } + return errors::Internal("Captured function \"", func_.name(), + "\" was called before it was instantiated."); + } + if (ctx->lib() != lib_) { + return errors::Internal("Captured function \"", func_.name(), + "\" was called with a different " + "FunctionLibraryRuntime*, which is not permitted."); } *out_handle = f_handle_; return Status::OK(); @@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate( Status CapturedFunction::Run(IteratorContext* ctx, std::vector&& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); + TF_RETURN_IF_ERROR(GetHandle(ctx, &handle)); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); @@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx, const std::vector& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); + TF_RETURN_IF_ERROR(GetHandle(ctx, &handle)); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); @@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx, } Status CapturedFunction::Instantiate(IteratorContext* ctx) { - FunctionLibraryRuntime::Handle unused_handle; - TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle)); mutex_lock l(mu_); + if (lib_ == nullptr) { + // The context's runtime will be used for all subsequent calls. + lib_ = ctx->lib(); + DCHECK(f_handle_ == kInvalidHandle); + FunctionLibraryRuntime::InstantiateOptions inst_opts; + inst_opts.overlay_lib = ctx->function_library().get(); + inst_opts.state_handle = std::to_string(random::New64()); + inst_opts.create_kernels_eagerly = true; + Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), + inst_opts, &f_handle_)); + TF_RETURN_IF_ERROR(s); + const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_); + if (fbody == nullptr) { + return errors::Internal("Failed to instantiate function body."); + } + ret_types_ = fbody->ret_types; + } else { + if (ctx->lib() != lib_) { + return errors::Internal( + "Captured function was called with a different " + "FunctionLibraryRuntime*, which is not permitted."); + } + } if (captured_runner_ == nullptr) { captured_runner_ = *ctx->runner(); } @@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // be deleted before `done` is called. Take care not to capture `ctx` in any // code that may execute asynchronously in this function. FunctionLibraryRuntime::Handle handle; - Status s = MaybeInstantiate(ctx, &handle); + Status s = GetHandle(ctx, &handle); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index e9ad3e381d4..c95f2b1c017 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -116,8 +116,8 @@ class CapturedFunction { CapturedFunction(const NameAttrList& func, std::vector captured_inputs); - Status MaybeInstantiate(IteratorContext* ctx, - FunctionLibraryRuntime::Handle* out_handle); + Status GetHandle(IteratorContext* ctx, + FunctionLibraryRuntime::Handle* out_handle); mutex mu_; const NameAttrList func_; diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index a80e102ccfa..f5c7d336a66 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 07bcb9d4145..21e627a8e81 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 3c3d78b724e..ccee690d7e6 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { } } + Status Initialize(IteratorContext* ctx) override { + TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + return Status::OK(); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (!initialized_) { - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); - // Explicitly instantiate the finalize function here so that - // we can invoke it in the destructor. - TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - initialized_ = true; - } - if (finalized_) { *end_of_sequence = true; return Status::OK(); @@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { private: mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; bool finalized_ GUARDED_BY(mu_) = false; std::vector state_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index 3f84fa9c2ec..84075431365 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/kernels/data/captured_function.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index be4132a064b..4a388645f22 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->captured_finalize_func_->Instantiate(ctx)); + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 288695f3cdc..f993a689341 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx)); + TF_RETURN_IF_ERROR( + dataset()->captured_window_size_func_->Instantiate(ctx)); + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 58b79d60266..6bba6677595 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -1,4 +1,3 @@ - /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { args_list_(params.dataset->cycle_length_) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) { diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 61a6c06135e..25beb02f0e4 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase { bool* end_of_sequence) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - if (lib_ != nullptr) { - ctx->set_lib(lib_); - } + CHECK_NOTNULL(lib_); + ctx->set_lib(lib_); return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence); } else { return errors::FailedPrecondition( @@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); std::unique_ptr iterator; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(lib); TF_RETURN_IF_ERROR( - dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); + dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); TF_RETURN_IF_ERROR(set_iterator(std::move(iterator))); std::shared_ptr captured_iterator(iterator_); @@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase { return lib_def_; } + FunctionLibraryRuntime* function_library_runtime() { return lib_; } + // Transfers ownership of iterator to this. This method is thread-safe. Status set_iterator(std::unique_ptr iterator) { if (iterator) { @@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { core::ScopedUnref unref(iterator_resource); std::unique_ptr iterator; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(iterator_resource->function_library_runtime()); OP_REQUIRES_OK( - ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); + ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } @@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel { DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); std::unique_ptr iter; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(lib); TF_RETURN_IF_ERROR( - dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter)); + dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter)); TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter))); (*iterator)->Ref(); diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 0e17011b051..c4df7f27567 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 294fb1c49a1..26ae26a7fdf 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index b097598cd94..b2d307ba8a1 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->optimized_input_->MakeIterator(ctx, prefix(), - &input_impl_); + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.lib = ctx->lib(); + params.function_library = dataset()->flib_def_; + params.allocator_getter = ctx->allocator_getter(); + return dataset()->optimized_input_->MakeIterator( + IteratorContext(params), prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index cfa96d910d3..bf86361a718 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } // It is implemented so that it matches the deterministic interleave diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index a407abfce45..e03a4e353bf 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { + auto init_func = [this](IteratorContext* ctx) { + return captured_func_->Instantiate(ctx); + }; + auto map_func = [this](IteratorContext* ctx, std::vector input_element, std::vector* result, StatusCallback done) { @@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return NewParallelMapIterator( {this, strings::StrCat(prefix, "::ParallelMap")}, input_, - std::move(map_func), num_parallel_calls_); + std::move(init_func), std::move(map_func), num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 4d32b719a42..61f8139b9e7 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator { public: explicit ParallelMapIterator( const typename DatasetBaseIterator::BaseParams& params, - const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, - int32 num_parallel_calls) + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls) : DatasetBaseIterator(params), input_dataset_(input_dataset), + init_func_(std::move(init_func)), map_func_(std::move(map_func)), num_parallel_calls_(num_parallel_calls) {} @@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status Initialize(IteratorContext* ctx) override { - return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); + if (init_func_) { + TF_RETURN_IF_ERROR(init_func_(ctx)); + } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator { } const DatasetBase* const input_dataset_; // Not owned. + const std::function init_func_; const ParallelMapIteratorFunction map_func_; const int32 num_parallel_calls_; // Used for coordination between the main thread and the runner thread. @@ -311,8 +319,18 @@ std::unique_ptr NewParallelMapIterator( const DatasetBaseIterator::BaseParams& params, const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr(new ParallelMapIterator( - params, input_dataset, std::move(map_func), num_parallel_calls)); + return NewParallelMapIterator(params, input_dataset, nullptr, + std::move(map_func), num_parallel_calls); +} + +std::unique_ptr NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { + return std::unique_ptr( + new ParallelMapIterator(params, input_dataset, std::move(init_func), + std::move(map_func), num_parallel_calls)); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 2ce36c38690..7e6cc586f30 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -33,7 +33,15 @@ using ParallelMapIteratorFunction = std::vector*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of -// `input_dataset` using the given degree of parallelism. +// `input_dataset` using the given degree of parallelism. `init_func` (if +// specified) will be executed when the iterator is initialized (see +// `IteratorBase::Initialize()`) and enables the user to specify error checking +// logic that can fail early. +std::unique_ptr NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, int32 num_parallel_calls); std::unique_ptr NewParallelMapIterator( const DatasetBaseIterator::BaseParams& params, const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 5e9ace3486e..299949b99f9 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { class ForeverIterator : public DatasetIterator { public: explicit ForeverIterator(const Params& params) - : DatasetIterator(params), input_impl_(nullptr) {} + : DatasetIterator(params), + input_impl_(nullptr), + first_call_(true) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. do { - bool first_call = false; if (!input_impl_) { - first_call = true; TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (!*end_of_sequence) { + Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + if (first_call_ && *end_of_sequence) { + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) + input_impl_.reset(); return Status::OK(); + } + first_call_ = false; + if (!*end_of_sequence) { + return s; } else { input_impl_.reset(); - if (first_call) { - // If the first call to GetNext() fails because the end - // of sequence has been reached, we terminate the - // iteration immediately. (Otherwise, this iterator - // would loop infinitely and never produce a value.) - return Status::OK(); - } + first_call_ = true; } } while (true); } @@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - if (input_impl_) + if (!first_call_) TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); else TF_RETURN_IF_ERROR( @@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); if (reader->Contains(full_name("uninitialized"))) { input_impl_.reset(); + first_call_ = true; } else { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + first_call_ = false; } return Status::OK(); } @@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { private: mutex mu_; std::unique_ptr input_impl_ GUARDED_BY(mu_); + bool first_call_ GUARDED_BY(mu_); }; const int64 count_; diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index e4cb31e2b2e..5d3319b19fa 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { state_(params.dataset->initial_state_) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 637bde9ae4e..52b4320bf1b 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -24,6 +24,7 @@ import warnings import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase): r"Dataset.map\(\): None."): _ = dataset.map(lambda x: None) + def testBrokenFunctionErrorOnInitialization(self): + dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0]) + + def broken_function(_): + """A function deliberately designed to fail on instantiation.""" + value = [] + tensor_value = attr_value_pb2.AttrValue() + tensor_value.tensor.CopyFrom( + tensor_util.make_tensor_proto( + value, dtype=dtypes.float32, shape=[0], verify_shape=False)) + dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum) + + # Create a "Const" op with a `tf.float32` value and a `tf.int32` type + # attr. + const_tensor = ops.get_default_graph().create_op( + "Const", [], [dtypes.int32], + attrs={ + "value": tensor_value, + "dtype": dtype_value + }, + name="BrokenConst").outputs[0] + return const_tensor + + dataset = dataset.map(broken_function) + iterator = dataset.make_initializable_iterator() + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"): + sess.run(iterator.initializer) + class MapDatasetBenchmark(test.Benchmark):