[tf.data] Move captured function instantiation to iterator initialization time.

Previously, a function instantiation error (e.g. in `Dataset.map()`) would lead
to an error in each `GetNext()` call that attempted to use the function. Moving
instantiation to iterator initialization time means the error is reported once,
when the initialization op is executed, with a more helpful stack trace, and it
is no longer conflated with other potential op failures.

PiperOrigin-RevId: 209633511
Derek Murray 2018-08-21 11:48:37 -07:00 committed by TensorFlower Gardener
parent e28f9da84b
commit 9158b1b83a
22 changed files with 210 additions and 81 deletions
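
To make the behavior change concrete, here is a minimal sketch (not part of the diff) of the pattern that the dataset kernels below adopt; the iterator class is a simplified stand-in, while `captured_func_` and `input_` follow the member names used in this change:

// Sketch: function-backed dataset iterators now instantiate their captured
// function in Initialize(), which runs when the iterator's initialization op
// executes, instead of lazily inside GetNext().
Status Iterator::Initialize(IteratorContext* ctx) {
  // Create the input iterator first, propagating any error.
  TF_RETURN_IF_ERROR(
      dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
  // Instantiation (with create_kernels_eagerly set, see below) is where a
  // broken function definition is now reported, exactly once.
  return dataset()->captured_func_->Instantiate(ctx);
}

GetNext() then only resolves the already-instantiated handle, so per-element failures are no longer mixed up with instantiation errors.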

View File

@@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
       next_handle_++;
     }
   }
+  if (options.create_kernels_eagerly) {
+    Item* item;
+    TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
+  }
   return Status::OK();
 }

View File

@@ -490,6 +490,11 @@ class FunctionLibraryRuntime {
     // Instantiates the function using an executor of the given type. If empty,
     // the default TensorFlow executor will be used.
     string executor_type;
+
+    // If true, the runtime will attempt to create kernels for the function at
+    // instantiation time, rather than on the first run. This can be used to
+    // surface errors earlier.
+    bool create_kernels_eagerly = false;
   };
   typedef uint64 Handle;
   virtual Status Instantiate(const string& function_name, AttrSlice attrs,
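
For illustration, a hedged sketch of how a FunctionLibraryRuntime client opts in to the new flag; it mirrors the usage added to `CapturedFunction::Instantiate()` below, with `lib`, `fn_name`, and `attrs` assumed to be in scope:

// Sketch: request eager kernel creation so that kernel-construction errors
// are returned from Instantiate() rather than surfacing on the first Run().
FunctionLibraryRuntime::InstantiateOptions inst_opts;
inst_opts.create_kernels_eagerly = true;  // flag added by this change

FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(lib->Instantiate(fn_name, attrs, inst_opts, &handle));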

View File

@@ -172,31 +172,17 @@ class BorrowedArgsCallFrame : public CallFrameBase {
 }  // namespace

-Status CapturedFunction::MaybeInstantiate(
-    IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
-  mutex_lock l(mu_);
+Status CapturedFunction::GetHandle(IteratorContext* ctx,
+                                   FunctionLibraryRuntime::Handle* out_handle) {
+  tf_shared_lock l(mu_);
   if (lib_ == nullptr) {
-    // The context's runtime will be used for all subsequent calls.
-    lib_ = ctx->lib();
-    DCHECK(f_handle_ == kInvalidHandle);
-    FunctionLibraryRuntime::InstantiateOptions inst_opts;
-    inst_opts.overlay_lib = ctx->function_library().get();
-    inst_opts.state_handle = std::to_string(random::New64());
-    TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
-                                         inst_opts, &f_handle_));
-    const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
-    if (fbody == nullptr) {
-      return errors::Internal("Failed to instantiate function body.");
-    }
-    ret_types_ = fbody->ret_types;
-  } else {
-    // TODO(mrry): Consider moving this under a shared lock, as it is
-    // the common case.
-    if (ctx->lib() != lib_) {
-      return errors::Internal(
-          "Captured function was called with a different "
-          "FunctionLibraryRuntime*, which is not permitted.");
-    }
+    return errors::Internal("Captured function \"", func_.name(),
+                            "\" was called before it was instantiated.");
+  }
+  if (ctx->lib() != lib_) {
+    return errors::Internal("Captured function \"", func_.name(),
+                            "\" was called with a different "
+                            "FunctionLibraryRuntime*, which is not permitted.");
   }
   *out_handle = f_handle_;
   return Status::OK();
@@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate(
 Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
                              std::vector<Tensor>* rets) {
   FunctionLibraryRuntime::Handle handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+  TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));

   FunctionLibraryRuntime::Options f_opts;
   f_opts.step_id = CapturedFunction::generate_step_id();
@@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
                                              const std::vector<Tensor>& args,
                                              std::vector<Tensor>* rets) {
   FunctionLibraryRuntime::Handle handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+  TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));

   FunctionLibraryRuntime::Options f_opts;
   f_opts.step_id = CapturedFunction::generate_step_id();
@@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
 }

 Status CapturedFunction::Instantiate(IteratorContext* ctx) {
-  FunctionLibraryRuntime::Handle unused_handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
   mutex_lock l(mu_);
+  if (lib_ == nullptr) {
+    // The context's runtime will be used for all subsequent calls.
+    lib_ = ctx->lib();
+    DCHECK(f_handle_ == kInvalidHandle);
+    FunctionLibraryRuntime::InstantiateOptions inst_opts;
+    inst_opts.overlay_lib = ctx->function_library().get();
+    inst_opts.state_handle = std::to_string(random::New64());
+    inst_opts.create_kernels_eagerly = true;
+    Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+                                  inst_opts, &f_handle_));
+    TF_RETURN_IF_ERROR(s);
+    const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+    if (fbody == nullptr) {
+      return errors::Internal("Failed to instantiate function body.");
+    }
+    ret_types_ = fbody->ret_types;
+  } else {
+    if (ctx->lib() != lib_) {
+      return errors::Internal(
+          "Captured function was called with a different "
+          "FunctionLibraryRuntime*, which is not permitted.");
+    }
+  }
   if (captured_runner_ == nullptr) {
     captured_runner_ = *ctx->runner();
   }
@@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
   // be deleted before `done` is called. Take care not to capture `ctx` in any
   // code that may execute asynchronously in this function.
   FunctionLibraryRuntime::Handle handle;
-  Status s = MaybeInstantiate(ctx, &handle);
+  Status s = GetHandle(ctx, &handle);
   if (!s.ok()) {
     done(s);
     return;
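
For readers of the diff, a brief hedged sketch of the call contract that results (caller-side code only; `captured_func`, `ctx`, and `args` are assumed to be in scope):

// Sketch: Instantiate() must run during Iterator::Initialize(); afterwards
// GetHandle() merely looks up the handle under a shared lock.
TF_RETURN_IF_ERROR(captured_func->Instantiate(ctx));  // at initialization
std::vector<Tensor> rets;
TF_RETURN_IF_ERROR(captured_func->Run(ctx, std::move(args), &rets));
// Skipping Instantiate() makes GetHandle() (and hence Run()) fail with
// errors::Internal("... was called before it was instantiated.").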

View File

@@ -116,8 +116,8 @@ class CapturedFunction {
   CapturedFunction(const NameAttrList& func,
                    std::vector<Tensor> captured_inputs);

-  Status MaybeInstantiate(IteratorContext* ctx,
-                          FunctionLibraryRuntime::Handle* out_handle);
+  Status GetHandle(IteratorContext* ctx,
+                   FunctionLibraryRuntime::Handle* out_handle);

   mutex mu_;
   const NameAttrList func_;

View File

@@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<FilterDatasetBase>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
 #include "tensorflow/core/lib/random/random.h"

 namespace tensorflow {
@@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
       }
     }

+    Status Initialize(IteratorContext* ctx) override {
+      TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+      return Status::OK();
+    }
+
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      if (!initialized_) {
-        TF_RETURN_IF_ERROR(
-            dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
-        // Explicitly instantiate the finalize function here so that
-        // we can invoke it in the destructor.
-        TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
-        initialized_ = true;
-      }
       if (finalized_) {
         *end_of_sequence = true;
         return Status::OK();
@@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
    private:
     mutex mu_;
-    bool initialized_ GUARDED_BY(mu_) = false;
     bool finalized_ GUARDED_BY(mu_) = false;
     std::vector<Tensor> state_ GUARDED_BY(mu_);
   };

View File

@@ -17,7 +17,6 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_

 #include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/captured_function.h"

 namespace tensorflow {

View File

@@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_finalize_func_->Instantiate(ctx));
+      return Status::OK();
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_window_size_func_->Instantiate(ctx));
+      return Status::OK();
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -1,4 +1,3 @@
 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
@@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
           args_list_(params.dataset->cycle_length_) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {

View File

@@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase {
                  bool* end_of_sequence) {
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
     if (captured_iterator) {
-      if (lib_ != nullptr) {
-        ctx->set_lib(lib_);
-      }
+      CHECK_NOTNULL(lib_);
+      ctx->set_lib(lib_);
       return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
     } else {
       return errors::FailedPrecondition(
@@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase {
     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

     std::unique_ptr<IteratorBase> iterator;
+    IteratorContext iter_ctx(ctx);
+    iter_ctx.set_lib(lib);
     TF_RETURN_IF_ERROR(
-        dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+        dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
     TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase {
     return lib_def_;
   }

+  FunctionLibraryRuntime* function_library_runtime() { return lib_; }
+
   // Transfers ownership of iterator to this. This method is thread-safe.
   Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
     if (iterator) {
@@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
   core::ScopedUnref unref(iterator_resource);

   std::unique_ptr<IteratorBase> iterator;
+  IteratorContext iter_ctx(ctx);
+  iter_ctx.set_lib(iterator_resource->function_library_runtime());
   OP_REQUIRES_OK(
-      ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+      ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
   OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
 }
@@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
   DatasetBase* dataset;
   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));

   std::unique_ptr<IteratorBase> iter;
+  IteratorContext iter_ctx(ctx);
+  iter_ctx.set_lib(lib);
   TF_RETURN_IF_ERROR(
-      dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
+      dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
   TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
   (*iterator)->Ref();

View File

@@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
-                                                       &input_impl_);
+      IteratorContext::Params params;
+      params.env = ctx->env();
+      params.runner = *(ctx->runner());
+      params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+      params.lib = ctx->lib();
+      params.function_library = dataset()->flib_def_;
+      params.allocator_getter = ctx->allocator_getter();
+      return dataset()->optimized_input_->MakeIterator(
+          IteratorContext(params), prefix(), &input_impl_);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     }

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     // It is implemented so that it matches the deterministic interleave
View File

@@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
+    auto init_func = [this](IteratorContext* ctx) {
+      return captured_func_->Instantiate(ctx);
+    };
+
     auto map_func = [this](IteratorContext* ctx,
                            std::vector<Tensor> input_element,
                            std::vector<Tensor>* result, StatusCallback done) {
@@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {

     return NewParallelMapIterator(
         {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
-        std::move(map_func), num_parallel_calls_);
+        std::move(init_func), std::move(map_func), num_parallel_calls_);
   }

   const DataTypeVector& output_dtypes() const override {

View File

@@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
  public:
   explicit ParallelMapIterator(
       const typename DatasetBaseIterator::BaseParams& params,
-      const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
-      int32 num_parallel_calls)
+      const DatasetBase* input_dataset,
+      std::function<Status(IteratorContext*)> init_func,
+      ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
       : DatasetBaseIterator(params),
         input_dataset_(input_dataset),
+        init_func_(std::move(init_func)),
         map_func_(std::move(map_func)),
         num_parallel_calls_(num_parallel_calls) {}
@@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }

   Status Initialize(IteratorContext* ctx) override {
-    return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+    TF_RETURN_IF_ERROR(
+        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+    if (init_func_) {
+      TF_RETURN_IF_ERROR(init_func_(ctx));
+    }
+    return Status::OK();
   }

   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }

   const DatasetBase* const input_dataset_;  // Not owned.
+  const std::function<Status(IteratorContext*)> init_func_;
   const ParallelMapIteratorFunction map_func_;
   const int32 num_parallel_calls_;
   // Used for coordination between the main thread and the runner thread.
@@ -311,8 +319,18 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
     const DatasetBaseIterator::BaseParams& params,
     const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
     int32 num_parallel_calls) {
-  return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
-      params, input_dataset, std::move(map_func), num_parallel_calls));
+  return NewParallelMapIterator(params, input_dataset, nullptr,
+                                std::move(map_func), num_parallel_calls);
+}
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+    const DatasetBaseIterator::BaseParams& params,
+    const DatasetBase* input_dataset,
+    std::function<Status(IteratorContext*)> init_func,
+    ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
+  return std::unique_ptr<IteratorBase>(
+      new ParallelMapIterator(params, input_dataset, std::move(init_func),
+                              std::move(map_func), num_parallel_calls));
 }

 }  // namespace tensorflow

View File

@@ -33,7 +33,15 @@ using ParallelMapIteratorFunction =
                        std::vector<Tensor>*, StatusCallback)>;

 // Returns a new iterator that applies `map_func` to the elements of
-// `input_dataset` using the given degree of parallelism.
+// `input_dataset` using the given degree of parallelism. `init_func` (if
+// specified) will be executed when the iterator is initialized (see
+// `IteratorBase::Initialize()`) and enables the user to specify error checking
+// logic that can fail early.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+    const DatasetBaseIterator::BaseParams& params,
+    const DatasetBase* input_dataset,
+    std::function<Status(IteratorContext*)> init_func,
+    ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
 std::unique_ptr<IteratorBase> NewParallelMapIterator(
     const DatasetBaseIterator::BaseParams& params,
     const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
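
As a usage sketch of the new overload (hedged; `params`, `input_dataset`, `captured_func`, `map_func`, and `num_parallel_calls` are assumed to be in scope, mirroring the ParallelMap kernel change earlier in this commit):

// Sketch: pass an init_func that instantiates the captured function when the
// iterator is initialized.  The older overload forwards a null init_func and
// therefore keeps the previous lazy behavior.
auto init_func = [captured_func](IteratorContext* ctx) {
  return captured_func->Instantiate(ctx);
};
auto iterator = NewParallelMapIterator(
    params, input_dataset, std::move(init_func), std::move(map_func),
    num_parallel_calls);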

View File

@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
   class ForeverIterator : public DatasetIterator<Dataset> {
    public:
     explicit ForeverIterator(const Params& params)
-        : DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
+        : DatasetIterator<Dataset>(params),
+          input_impl_(nullptr),
+          first_call_(true) {}
+
+    Status Initialize(IteratorContext* ctx) override {
+      mutex_lock l(mu_);
+      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+    }

     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
       do {
-        bool first_call = false;
         if (!input_impl_) {
-          first_call = true;
           TF_RETURN_IF_ERROR(
               dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         }
-        TF_RETURN_IF_ERROR(
-            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+        Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+        if (first_call_ && *end_of_sequence) {
+          // If the first call to GetNext() fails because the end
+          // of sequence has been reached, we terminate the
+          // iteration immediately. (Otherwise, this iterator
+          // would loop infinitely and never produce a value.)
+          input_impl_.reset();
+          return Status::OK();
+        }
+        first_call_ = false;
         if (!*end_of_sequence) {
-          return Status::OK();
+          return s;
         } else {
           input_impl_.reset();
-          if (first_call) {
-            // If the first call to GetNext() fails because the end
-            // of sequence has been reached, we terminate the
-            // iteration immediately. (Otherwise, this iterator
-            // would loop infinitely and never produce a value.)
-            return Status::OK();
-          }
+          first_call_ = true;
         }
       } while (true);
     }
@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      if (input_impl_)
+      if (!first_call_)
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
       else
         TF_RETURN_IF_ERROR(
@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
       mutex_lock l(mu_);
       if (reader->Contains(full_name("uninitialized"))) {
         input_impl_.reset();
+        first_call_ = true;
       } else {
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        first_call_ = false;
       }
       return Status::OK();
     }
@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
    private:
     mutex mu_;
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    bool first_call_ GUARDED_BY(mu_);
   };

   const int64 count_;

View File

@@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
           state_(params.dataset->initial_state_) {}

     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }

     Status GetNextInternal(IteratorContext* ctx,

View File

@@ -24,6 +24,7 @@ import warnings

 import numpy as np

+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase):
                                  r"Dataset.map\(\): None."):
       _ = dataset.map(lambda x: None)

+  def testBrokenFunctionErrorOnInitialization(self):
+    dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
+
+    def broken_function(_):
+      """A function deliberately designed to fail on instantiation."""
+      value = []
+      tensor_value = attr_value_pb2.AttrValue()
+      tensor_value.tensor.CopyFrom(
+          tensor_util.make_tensor_proto(
+              value, dtype=dtypes.float32, shape=[0], verify_shape=False))
+      dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
+
+      # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
+      # attr.
+      const_tensor = ops.get_default_graph().create_op(
+          "Const", [], [dtypes.int32],
+          attrs={
+              "value": tensor_value,
+              "dtype": dtype_value
+          },
+          name="BrokenConst").outputs[0]
+      return const_tensor
+
+    dataset = dataset.map(broken_function)
+    iterator = dataset.make_initializable_iterator()
+
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
+        sess.run(iterator.initializer)
+

 class MapDatasetBenchmark(test.Benchmark):