[tf.data] Move captured function instantiation to iterator initialization time.
Previously, a function instantiation error (e.g. in `Dataset.map()`) would lead to an error in each GetNext() call that attempted to use the function. Moving instantiation to iterator initialization time means the error is reported once, when the initialization op is executed, with a more helpful stack trace that is not conflated with other potential op failures.

PiperOrigin-RevId: 209633511
parent e28f9da84b
commit 9158b1b83a
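The recurring pattern across the dataset kernels below is small and uniform: each iterator's Initialize() now instantiates its captured function(s) right after creating the input iterator, so a broken function fails the initializer op instead of every GetNext() call. A minimal sketch of that shape, simplified from the diffs that follow (the surrounding class and member names are placeholders, not any one kernel):

// Sketch only: simplified from the Initialize() overrides changed in this
// commit; dataset()/input_/captured_func_ stand in for the real members.
Status Initialize(IteratorContext* ctx) override {
  // Build the input iterator first and propagate any error immediately.
  TF_RETURN_IF_ERROR(
      dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
  // Instantiate the captured user-defined function eagerly so that an
  // instantiation error is reported once, from the initialization op,
  // rather than from every subsequent GetNext() call.
  return dataset()->captured_func_->Instantiate(ctx);
}

The iterator ops are adjusted accordingly: they set the FunctionLibraryRuntime on the IteratorContext before calling MakeIterator(), so the captured functions can be instantiated at that point.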
@@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
       next_handle_++;
     }
   }
 
+  if (options.create_kernels_eagerly) {
+    Item* item;
+    TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
+  }
+
   return Status::OK();
 }
@@ -490,6 +490,11 @@ class FunctionLibraryRuntime {
     // Instantiates the function using an executor of the given type. If empty,
     // the default TensorFlow executor will be used.
     string executor_type;
+
+    // If true, the runtime will attempt to create kernels for the function at
+    // instantiation time, rather than on the first run. This can be used to
+    // surface errors earlier.
+    bool create_kernels_eagerly = false;
   };
   typedef uint64 Handle;
   virtual Status Instantiate(const string& function_name, AttrSlice attrs,
@@ -172,32 +172,18 @@ class BorrowedArgsCallFrame : public CallFrameBase {
 
 }  // namespace
 
-Status CapturedFunction::MaybeInstantiate(
-    IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
-  mutex_lock l(mu_);
+Status CapturedFunction::GetHandle(IteratorContext* ctx,
+                                   FunctionLibraryRuntime::Handle* out_handle) {
+  tf_shared_lock l(mu_);
   if (lib_ == nullptr) {
-    // The context's runtime will be used for all subsequent calls.
-    lib_ = ctx->lib();
-    DCHECK(f_handle_ == kInvalidHandle);
-    FunctionLibraryRuntime::InstantiateOptions inst_opts;
-    inst_opts.overlay_lib = ctx->function_library().get();
-    inst_opts.state_handle = std::to_string(random::New64());
-    TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
-                                         inst_opts, &f_handle_));
-    const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
-    if (fbody == nullptr) {
-      return errors::Internal("Failed to instantiate function body.");
-    }
-    ret_types_ = fbody->ret_types;
-  } else {
-    // TODO(mrry): Consider moving this under a shared lock, as it is
-    // the common case.
-    if (ctx->lib() != lib_) {
-      return errors::Internal(
-          "Captured function was called with a different "
-          "FunctionLibraryRuntime*, which is not permitted.");
-    }
+    return errors::Internal("Captured function \"", func_.name(),
+                            "\" was called before it was instantiated.");
+  }
+  if (ctx->lib() != lib_) {
+    return errors::Internal("Captured function \"", func_.name(),
+                            "\" was called with a different "
+                            "FunctionLibraryRuntime*, which is not permitted.");
   }
   *out_handle = f_handle_;
   return Status::OK();
 }
@@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate(
 Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
                              std::vector<Tensor>* rets) {
   FunctionLibraryRuntime::Handle handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+  TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
 
   FunctionLibraryRuntime::Options f_opts;
   f_opts.step_id = CapturedFunction::generate_step_id();
@@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
                                              const std::vector<Tensor>& args,
                                              std::vector<Tensor>* rets) {
   FunctionLibraryRuntime::Handle handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+  TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
 
   FunctionLibraryRuntime::Options f_opts;
   f_opts.step_id = CapturedFunction::generate_step_id();
@@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
 }
 
 Status CapturedFunction::Instantiate(IteratorContext* ctx) {
-  FunctionLibraryRuntime::Handle unused_handle;
-  TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
   mutex_lock l(mu_);
+  if (lib_ == nullptr) {
+    // The context's runtime will be used for all subsequent calls.
+    lib_ = ctx->lib();
+    DCHECK(f_handle_ == kInvalidHandle);
+    FunctionLibraryRuntime::InstantiateOptions inst_opts;
+    inst_opts.overlay_lib = ctx->function_library().get();
+    inst_opts.state_handle = std::to_string(random::New64());
+    inst_opts.create_kernels_eagerly = true;
+    Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+                                  inst_opts, &f_handle_));
+    TF_RETURN_IF_ERROR(s);
+    const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+    if (fbody == nullptr) {
+      return errors::Internal("Failed to instantiate function body.");
+    }
+    ret_types_ = fbody->ret_types;
+  } else {
+    if (ctx->lib() != lib_) {
+      return errors::Internal(
+          "Captured function was called with a different "
+          "FunctionLibraryRuntime*, which is not permitted.");
+    }
+  }
   if (captured_runner_ == nullptr) {
     captured_runner_ = *ctx->runner();
   }
@@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
   // be deleted before `done` is called. Take care not to capture `ctx` in any
   // code that may execute asynchronously in this function.
   FunctionLibraryRuntime::Handle handle;
-  Status s = MaybeInstantiate(ctx, &handle);
+  Status s = GetHandle(ctx, &handle);
   if (!s.ok()) {
     done(s);
     return;
@@ -116,7 +116,7 @@ class CapturedFunction {
   CapturedFunction(const NameAttrList& func,
                    std::vector<Tensor> captured_inputs);
 
-  Status MaybeInstantiate(IteratorContext* ctx,
-                          FunctionLibraryRuntime::Handle* out_handle);
+  Status GetHandle(IteratorContext* ctx,
+                   FunctionLibraryRuntime::Handle* out_handle);
 
   mutex mu_;
@@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
         : DatasetIterator<FilterDatasetBase>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
        : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
@@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
       }
     }
 
+    Status Initialize(IteratorContext* ctx) override {
+      TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+      return Status::OK();
+    }
+
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
 
-      if (!initialized_) {
-        TF_RETURN_IF_ERROR(
-            dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
-        // Explicitly instantiate the finalize function here so that
-        // we can invoke it in the destructor.
-        TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
-        initialized_ = true;
-      }
-
       if (finalized_) {
         *end_of_sequence = true;
         return Status::OK();
@@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
 
    private:
     mutex mu_;
-    bool initialized_ GUARDED_BY(mu_) = false;
     bool finalized_ GUARDED_BY(mu_) = false;
     std::vector<Tensor> state_ GUARDED_BY(mu_);
   };
@@ -17,7 +17,6 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
 
 #include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
 
 namespace tensorflow {
 
@@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_finalize_func_->Instantiate(ctx));
+      return Status::OK();
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_window_size_func_->Instantiate(ctx));
+      return Status::OK();
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -1,4 +1,3 @@
-
 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
           args_list_(params.dataset->cycle_length_) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase {
                  bool* end_of_sequence) {
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
     if (captured_iterator) {
-      if (lib_ != nullptr) {
-        ctx->set_lib(lib_);
-      }
+      CHECK_NOTNULL(lib_);
+      ctx->set_lib(lib_);
       return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
     } else {
       return errors::FailedPrecondition(
@@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase {
     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
 
     std::unique_ptr<IteratorBase> iterator;
+    IteratorContext iter_ctx(ctx);
+    iter_ctx.set_lib(lib);
     TF_RETURN_IF_ERROR(
-        dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+        dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
     TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
 
@@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase {
     return lib_def_;
   }
 
+  FunctionLibraryRuntime* function_library_runtime() { return lib_; }
+
   // Transfers ownership of iterator to this. This method is thread-safe.
   Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
     if (iterator) {
@@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
   core::ScopedUnref unref(iterator_resource);
 
   std::unique_ptr<IteratorBase> iterator;
+  IteratorContext iter_ctx(ctx);
+  iter_ctx.set_lib(iterator_resource->function_library_runtime());
   OP_REQUIRES_OK(
-      ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+      ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
   OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
 }
 
@@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
   DatasetBase* dataset;
   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
   std::unique_ptr<IteratorBase> iter;
+  IteratorContext iter_ctx(ctx);
+  iter_ctx.set_lib(lib);
   TF_RETURN_IF_ERROR(
-      dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
+      dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
   TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
 
   (*iterator)->Ref();
@@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
         : DatasetIterator<Dataset>(params) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
-                                                       &input_impl_);
+      IteratorContext::Params params;
+      params.env = ctx->env();
+      params.runner = *(ctx->runner());
+      params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+      params.lib = ctx->lib();
+      params.function_library = dataset()->flib_def_;
+      params.allocator_getter = ctx->allocator_getter();
+      return dataset()->optimized_input_->MakeIterator(
+          IteratorContext(params), prefix(), &input_impl_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
    }
 
     // It is implemented so that it matches the deterministic interleave
@@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
 
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
+    auto init_func = [this](IteratorContext* ctx) {
+      return captured_func_->Instantiate(ctx);
+    };
+
     auto map_func = [this](IteratorContext* ctx,
                            std::vector<Tensor> input_element,
                            std::vector<Tensor>* result, StatusCallback done) {
@@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
 
     return NewParallelMapIterator(
         {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
-        std::move(map_func), num_parallel_calls_);
+        std::move(init_func), std::move(map_func), num_parallel_calls_);
   }
 
   const DataTypeVector& output_dtypes() const override {
@@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
  public:
   explicit ParallelMapIterator(
       const typename DatasetBaseIterator::BaseParams& params,
-      const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
-      int32 num_parallel_calls)
+      const DatasetBase* input_dataset,
+      std::function<Status(IteratorContext*)> init_func,
+      ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
       : DatasetBaseIterator(params),
         input_dataset_(input_dataset),
+        init_func_(std::move(init_func)),
         map_func_(std::move(map_func)),
         num_parallel_calls_(num_parallel_calls) {}
 
@@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }
 
   Status Initialize(IteratorContext* ctx) override {
-    return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+    TF_RETURN_IF_ERROR(
+        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+    if (init_func_) {
+      TF_RETURN_IF_ERROR(init_func_(ctx));
+    }
+    return Status::OK();
   }
 
   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }
 
   const DatasetBase* const input_dataset_;  // Not owned.
+  const std::function<Status(IteratorContext*)> init_func_;
   const ParallelMapIteratorFunction map_func_;
   const int32 num_parallel_calls_;
   // Used for coordination between the main thread and the runner thread.
@@ -311,8 +319,18 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
     const DatasetBaseIterator::BaseParams& params,
     const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
     int32 num_parallel_calls) {
-  return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
-      params, input_dataset, std::move(map_func), num_parallel_calls));
+  return NewParallelMapIterator(params, input_dataset, nullptr,
+                                std::move(map_func), num_parallel_calls);
+}
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+    const DatasetBaseIterator::BaseParams& params,
+    const DatasetBase* input_dataset,
+    std::function<Status(IteratorContext*)> init_func,
+    ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
+  return std::unique_ptr<IteratorBase>(
+      new ParallelMapIterator(params, input_dataset, std::move(init_func),
+                              std::move(map_func), num_parallel_calls));
 }
 
 }  // namespace tensorflow
@@ -33,7 +33,15 @@ using ParallelMapIteratorFunction =
                          std::vector<Tensor>*, StatusCallback)>;
 
 // Returns a new iterator that applies `map_func` to the elements of
-// `input_dataset` using the given degree of parallelism.
+// `input_dataset` using the given degree of parallelism. `init_func` (if
+// specified) will be executed when the iterator is initialized (see
+// `IteratorBase::Initialize()`) and enables the user to specify error checking
+// logic that can fail early.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+    const DatasetBaseIterator::BaseParams& params,
+    const DatasetBase* input_dataset,
+    std::function<Status(IteratorContext*)> init_func,
+    ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
 std::unique_ptr<IteratorBase> NewParallelMapIterator(
     const DatasetBaseIterator::BaseParams& params,
     const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
     class ForeverIterator : public DatasetIterator<Dataset> {
      public:
       explicit ForeverIterator(const Params& params)
-          : DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
+          : DatasetIterator<Dataset>(params),
+            input_impl_(nullptr),
+            first_call_(true) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        mutex_lock l(mu_);
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
         do {
-          bool first_call = false;
           if (!input_impl_) {
-            first_call = true;
             TF_RETURN_IF_ERROR(
                 dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           }
-          TF_RETURN_IF_ERROR(
-              input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
-          if (!*end_of_sequence) {
-            return Status::OK();
-          } else {
-            input_impl_.reset();
-            if (first_call) {
-              // If the first call to GetNext() fails because the end
-              // of sequence has been reached, we terminate the
-              // iteration immediately. (Otherwise, this iterator
-              // would loop infinitely and never produce a value.)
-              return Status::OK();
-            }
+          Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+          if (first_call_ && *end_of_sequence) {
+            // If the first call to GetNext() fails because the end
+            // of sequence has been reached, we terminate the
+            // iteration immediately. (Otherwise, this iterator
+            // would loop infinitely and never produce a value.)
+            input_impl_.reset();
+            return Status::OK();
+          }
+          first_call_ = false;
+          if (!*end_of_sequence) {
+            return s;
+          } else {
+            input_impl_.reset();
+            first_call_ = true;
           }
         } while (true);
       }
@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
      protected:
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        if (input_impl_)
+        if (!first_call_)
           TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
         else
           TF_RETURN_IF_ERROR(
@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
         mutex_lock l(mu_);
         if (reader->Contains(full_name("uninitialized"))) {
           input_impl_.reset();
+          first_call_ = true;
         } else {
           TF_RETURN_IF_ERROR(
               dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+          first_call_ = false;
         }
         return Status::OK();
       }
@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
      private:
       mutex mu_;
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      bool first_call_ GUARDED_BY(mu_);
     };
 
     const int64 count_;
@@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
           state_(params.dataset->initial_state_) {}
 
     Status Initialize(IteratorContext* ctx) override {
-      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+      return dataset()->captured_func_->Instantiate(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -24,6 +24,7 @@ import warnings
 
 import numpy as np
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase):
                                  r"Dataset.map\(\): None."):
       _ = dataset.map(lambda x: None)
 
+  def testBrokenFunctionErrorOnInitialization(self):
+    dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
+
+    def broken_function(_):
+      """A function deliberately designed to fail on instantiation."""
+      value = []
+      tensor_value = attr_value_pb2.AttrValue()
+      tensor_value.tensor.CopyFrom(
+          tensor_util.make_tensor_proto(
+              value, dtype=dtypes.float32, shape=[0], verify_shape=False))
+      dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
+
+      # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
+      # attr.
+      const_tensor = ops.get_default_graph().create_op(
+          "Const", [], [dtypes.int32],
+          attrs={
+              "value": tensor_value,
+              "dtype": dtype_value
+          },
+          name="BrokenConst").outputs[0]
+      return const_tensor
+
+    dataset = dataset.map(broken_function)
+    iterator = dataset.make_initializable_iterator()
+
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
+        sess.run(iterator.initializer)
+
 
 class MapDatasetBenchmark(test.Benchmark):
 