diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 315f96d4b5d..dc2663d1e0c 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <atomic>
 #include <deque>
+#include <memory>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
@@ -21,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/data/captured_function.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -43,12 +45,7 @@ namespace {
 //
 // Furthermore, this class favors modularity over extended functionality. In
 // particular, it refrains from implementing configurable buffering of output
-// elements and prefetching of input iterators, relying on other parts of
-// tf.data to provide this functionality if necessary.
-//
-// The above design choices were made with automated optimizations in mind,
-// isolating the degree of parallelism as the single tunable knob of this
-// implementation.
+// elements and prefetching of input iterators.
 class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
@@ -237,27 +234,20 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      std::shared_ptr<InvocationResult> result;
-      do {
-        result.reset();
-        {
-          mutex_lock l(*mu_);
-          EnsureRunnerThreadStarted(ctx);
-          while (ShouldWait(&result)) {
-            RecordStop(ctx);
-            cond_var_->wait(l);
-            RecordStart(ctx);
-          }
-          if (!result) {
-            *end_of_sequence = true;
-            return Status::OK();
-          }
+      std::shared_ptr<Result> result;
+      {
+        mutex_lock l(*mu_);
+        EnsureThreadsStarted(ctx);
+        while (!Consume(&result)) {
+          RecordStop(ctx);
+          cond_var_->wait(l);
+          RecordStart(ctx);
         }
-        RecordStop(ctx);
-        result->notification.WaitForNotification();
-        RecordStart(ctx);
-      } while (result->skip);
-
+      }
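+      // `Consume` leaves `result` null only when every current cycle element
+      // is exhausted and no future elements remain, i.e. at end of input.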
+      if (!result) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       if (result->status.ok()) {
         *out_tensors = std::move(result->return_values);
         RecordBufferDequeue(ctx, *out_tensors);
@@ -281,37 +271,22 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
       while (num_calls_ > 0) {
        cond_var_->wait(l);
       }
-      CHECK_EQ(num_calls_, 0);
+      DCHECK_EQ(num_calls_, 0);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
-      TF_RETURN_IF_ERROR(writer->WriteScalar(
-          full_name("invocation_results.size"), invocation_results_.size()));
-      for (size_t i = 0; i < invocation_results_.size(); i++) {
-        std::shared_ptr<InvocationResult> result = invocation_results_[i];
-        TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(strings::StrCat("invocation_results[", i, "].size")),
-            result->return_values.size()));
-        for (size_t j = 0; j < result->return_values.size(); j++) {
-          TF_RETURN_IF_ERROR(writer->WriteTensor(
-              full_name(
-                  strings::StrCat("invocation_results[", i, "][", j, "]")),
-              result->return_values[j]));
-        }
-        if (result->skip) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat("invocation_results[", i, "].skip")),
-              ""));
-        }
-      }
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name("block_index"), block_index_));
       TF_RETURN_IF_ERROR(
           writer->WriteScalar(full_name("cycle_index"), cycle_index_));
       if (end_of_input_) {
         TF_RETURN_IF_ERROR(
             writer->WriteScalar(full_name("end_of_input"), ""));
       }
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("element_id_counter"),
+                                             element_id_counter_));
       TF_RETURN_IF_ERROR(
           writer->WriteScalar(full_name("num_open"), num_open_));
       TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+      TF_RETURN_IF_ERROR(WriteFutureElements(writer));
       return Status::OK();
     }
 
@@ -319,200 +294,207 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
-      int64 invocation_results_size;
-      TF_RETURN_IF_ERROR(reader->ReadScalar(
-          full_name("invocation_results.size"), &invocation_results_size));
-      for (size_t i = 0; i < invocation_results_size; i++) {
-        std::shared_ptr<InvocationResult> result(new InvocationResult());
-        invocation_results_.push_back(result);
-        TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
-        size_t num_return_values;
-        {
-          int64 size;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(
-              full_name(strings::StrCat("invocation_results[", i, "].size")),
-              &size));
-          num_return_values = static_cast<size_t>(size);
-          if (num_return_values != size) {
-            return errors::InvalidArgument(strings::StrCat(
-                full_name(
-                    strings::StrCat("invocation_results[", i, "].size")),
-                ": ", size, " is not a valid value of type size_t."));
-          }
-        }
-        result->return_values.reserve(num_return_values);
-        for (size_t j = 0; j < num_return_values; j++) {
-          result->return_values.emplace_back();
-          TF_RETURN_IF_ERROR(
-              reader->ReadTensor(full_name(strings::StrCat(
-                                     "invocation_results[", i, "][", j, "]")),
-                                 &result->return_values.back()));
-        }
-        result->skip = reader->Contains(
-            full_name(strings::StrCat("invocation_results[", i, "].skip")));
-        result->notification.Notify();
-      }
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name("block_index"), &block_index_));
       TF_RETURN_IF_ERROR(
           reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("element_id_counter"),
+                                            &element_id_counter_));
       if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
       TF_RETURN_IF_ERROR(
           reader->ReadScalar(full_name("num_open"), &num_open_));
       TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+      TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
       return Status::OK();
     }
 
    private:
+    // Represents the result of fetching an element from a dataset.
+    struct Result {
+      Status status;
+      std::vector<Tensor> return_values;
+      // Indicates whether the result is ready to be consumed.
+      bool is_ready = false;
+    };
+
+    // The interleave transformation repeatedly inputs elements, applies the
+    // user-provided function to transform the input elements to datasets, and
+    // interleaves the elements of these datasets as its output.
+    //
+    // This structure represents an input element and derived state.
     struct Element {
+      // Unique identifier, needed to support checkpointing.
+      int64 id;
+      // The actual input element.
+      std::vector<Tensor> inputs;
+      // Iterator created from the input element.
       std::unique_ptr<IteratorBase> iterator;
-      std::vector<Tensor> inputs;  // inputs for creating the iterator
-      bool in_use;
+      mutex mu;
+      // Buffer for storing the outputs of `iterator`.
+      std::deque<std::shared_ptr<Result>> results GUARDED_BY(mu);
+      // Indicates whether the element is used by a worker thread.
+      bool in_use = false;
     };
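+
+    // Lock-ordering note: when both `*mu_` and `element->mu` are needed,
+    // `*mu_` is acquired first (see `FetchResults`), keeping the lock order
+    // consistent across all threads and avoiding deadlock.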
 
-    struct InvocationResult {
-      Notification notification;  // used for coordination with the consumer
-      Status status;              // the invocation status
-      std::vector<Tensor> return_values;  // the invocation result values
-      bool skip;  // if set the result should be skipped
-    };
+    // Advances the position in the interleave cycle to the next cycle
+    // element.
+    void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      block_index_ = 0;
+      cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+    }
 
-    void EnsureRunnerThreadStarted(IteratorContext* ctx)
+    // Advances the position in the interleave cycle by one.
+    void AdvancePosition() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      ++block_index_;
+      if (block_index_ == dataset()->block_length_) {
+        AdvanceToNextInCycle();
+      }
+    }
+
+    // Consumes a result (if available), returning an indication of whether
+    // a result is available. If `true` is returned, `result` either
+    // points to a valid result or is null if end of input has been reached.
+    bool Consume(std::shared_ptr<Result>* result)
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-      if (!runner_thread_) {
-        std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-        runner_thread_.reset(ctx->env()->StartThread(
-            {}, "tf_data_parallel_interleave_runner",
-            [this, new_ctx]() { RunnerThread(new_ctx); }));
-      }
+      if (!sloppy_) {
+        return ConsumeHelper(result);
+      }
+      // If we are allowed to be sloppy (i.e. return results out of order),
+      // try to find an element in the cycle that has a result available.
+      for (int i = 0; i < dataset()->cycle_length_; ++i) {
+        if (ConsumeHelper(result)) {
+          return true;
+        }
+        AdvanceToNextInCycle();
+      }
+      return false;
     }
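 
+    // Implements the deterministic consumption policy: inspects the element
+    // at the current cycle position and either consumes its ready result,
+    // retires it on end of input, or signals the caller to wait.
+    //
+    // Illustrative example (not part of the original change): with
+    // cycle_length = 2 and block_length = 3, consumption advances through
+    // (cycle_index_, block_index_) = (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
+    // (0,0) ..., producing the familiar A A A B B B A A A ... interleave
+    // order over the two cycle elements.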
+    bool ConsumeHelper(std::shared_ptr<Result>* result)
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      while (true) {
+        std::shared_ptr<Element> element = current_elements_[cycle_index_];
+        if (element) {
+          mutex_lock l(element->mu);
+          if (!element->results.empty()) {
+            if (element->results.front()->is_ready) {
+              // We found a result.
+              std::swap(*result, element->results.front());
+              element->results.pop_front();
+              AdvancePosition();
+              cond_var_->notify_all();
+              return true;
+            } else {
+              // Wait for the result to become ready.
+              return false;
+            }
+          } else if (!element->iterator) {
+            // We reached the end of input for this element. Reset
+            // it and move on to the next cycle element.
+            current_elements_[cycle_index_].reset();
+            AdvanceToNextInCycle();
+            cond_var_->notify_all();
+            continue;
+          } else {
+            // Wait for the iterator to produce a result.
+            return false;
+          }
+        } else {
+          if (!future_elements_.empty() || !end_of_input_) {
+            // Wait for an element to be created.
+            return false;
+          }
+          // No new elements will be created; try to find a
+          // non-empty element in the cycle.
+          for (int i = 0; i < dataset()->cycle_length_; ++i) {
+            AdvanceToNextInCycle();
+            if (current_elements_[cycle_index_]) {
+              break;
+            }
+          }
+          if (current_elements_[cycle_index_]) {
+            continue;
+          }
+          // End of input has been reached.
+          return true;
+        }
+      }
+    }
 
-    // Fetches up to `results.size()` outputs from the cycle element at
-    // position `cycle_index`.
+    // Manages current cycle elements, creating new iterators as needed and
+    // asynchronously fetching results from existing iterators.
     //
-    // If end of input is encountered, the `skip` field of the invocation
-    // result is used to identify results that should be skipped.
-    void FetchOutputs(
-        const std::shared_ptr<IteratorContext>& ctx, IteratorBase* iterator,
-        int64 cycle_index,
-        const std::vector<std::shared_ptr<InvocationResult>>& results)
-        LOCKS_EXCLUDED(*mu_) {
-      RecordStart(ctx.get());
-      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
-      bool end_of_input = false;
-      for (auto& result : results) {
-        if (!end_of_input) {
-          result->status = iterator->GetNext(
-              ctx.get(), &result->return_values, &end_of_input);
-        }
-        if (end_of_input) {
-          result->skip = true;
-        }
-        RecordBufferEnqueue(ctx.get(), result->return_values);
-        {
-          mutex_lock l(*mu_);
-          result->notification.Notify();
-          cond_var_->notify_all();
-        }
-        if (!result->status.ok()) {
-          break;
-        }
-      }
-
-      mutex_lock l(*mu_);
-      current_elements_[cycle_index].in_use = false;
-      if (end_of_input) {
-        // Release the ownership of the cycle element iterator, closing the
-        // iterator if end of input was encountered.
-        current_elements_[cycle_index].iterator.reset();
-        current_elements_[cycle_index].inputs.clear();
-        num_open_--;
-      }
-      num_calls_--;
-      const auto& stats_aggregator = ctx->stats_aggregator();
-      if (stats_aggregator) {
-        stats_aggregator->AddScalar(
-            strings::StrCat(key_prefix_, "::thread_utilization"),
-            static_cast<float>(num_calls_) /
-                static_cast<float>(num_parallel_calls_->value));
-      }
-      cond_var_->notify_all();
-    }
-
-    // Method responsible for 1) creating iterators out of input elements, 2)
-    // determining the order in which elements are fetched from the iterators,
-    // and 3) scheduling the fetching of the elements to a threadpool.
-    //
-    // This method runs in the `runner_thread` background thread.
-    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+    // This method runs in the `current_elements_manager_` background thread.
+    void CurrentElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
       RecordStart(ctx.get());
       auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
       auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
-        return current_elements_[cycle_index_].in_use ||
-               num_calls_ >= num_parallel_calls_->value ||
-               invocation_results_.size() >=
-                   dataset()->cycle_length_ * dataset()->block_length_;
+        const bool has_more_elements =
+            !future_elements_.empty() || !end_of_input_;
+        const int block_length = dataset()->block_length_;
+        bool all_elements_busy = true;
+        for (auto& element : current_elements_) {
+          if (!element) {
+            if (has_more_elements) {
+              all_elements_busy = false;
+              break;
+            }
+          } else {
+            mutex_lock l(element->mu);
+            if (!element->in_use && element->iterator &&
+                element->results.size() < block_length) {
+              all_elements_busy = false;
+              break;
+            }
+          }
+        }
+        return all_elements_busy || num_calls_ >= num_parallel_calls_->value;
       };
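+      // `busy()` is true when no useful work can be scheduled: every cycle
+      // slot is either empty with nothing left to fill it, in use,
+      // exhausted, or already holds `block_length_` buffered results, or
+      // the in-flight call count has reached the parallelism limit.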
       while (true) {
         mutex_lock l(*mu_);
+
         // Wait until this thread is cancelled, the end of input has been
-        // reached, or the cycle element at the `cycle_index_` position is
-        // not in use and there is space in the `invocation_results_` queue.
+        // reached, or some current element can accept another fetch call.
         while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
           RecordStop(ctx.get());
           cond_var_->wait(l);
           RecordStart(ctx.get());
         }
 
-        if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+        if (cancelled_ ||
+            (future_elements_.empty() && end_of_input_ && num_open_ == 0)) {
           return;
         }
 
-        while ((!end_of_input_ || num_open_ > 0) && !busy()) {
-          if (!current_elements_[cycle_index_].iterator) {
-            // Try to create a new iterator from the next input element.
-            Status status = input_impl_->GetNext(
-                ctx.get(), &current_elements_[cycle_index_].inputs,
-                &end_of_input_);
-            if (!status.ok()) {
-              invocation_results_.emplace_back(new InvocationResult());
-              std::shared_ptr<InvocationResult>& result =
-                  invocation_results_.back();
-              result->status.Update(status);
-              result->notification.Notify();
-              break;
-            }
-            if (!end_of_input_) {
-              Status status = MakeIteratorFromInputElement(
-                  ctx.get(), current_elements_[cycle_index_].inputs,
-                  cycle_index_, *instantiated_captured_func_, prefix(),
-                  &current_elements_[cycle_index_].iterator);
-              if (!status.ok()) {
-                invocation_results_.emplace_back(new InvocationResult());
-                std::shared_ptr<InvocationResult>& result =
-                    invocation_results_.back();
-                result->status.Update(status);
-                result->notification.Notify();
-                break;
-              }
-              ++num_open_;
-            }
-          }
-          if (current_elements_[cycle_index_].iterator) {
-            // Pre-allocate invocation results for outputs to be fetched
-            // and then fetch the outputs asynchronously.
-            std::vector<std::shared_ptr<InvocationResult>> results;
-            results.reserve(dataset()->block_length_);
-            for (int i = 0; i < dataset()->block_length_; ++i) {
-              invocation_results_.emplace_back(new InvocationResult());
-              results.push_back(invocation_results_.back());
-            }
-            num_calls_++;
-            current_elements_[cycle_index_].in_use = true;
-            thread_pool_->Schedule(
-                std::bind(&ParallelInterleaveIterator::FetchOutputs, this,
-                          ctx, current_elements_[cycle_index_].iterator.get(),
-                          cycle_index_, std::move(results)));
-          }
-          cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
-        }
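+        // Sweep the cycle once: fill empty slots from `future_elements_`
+        // (or fresh input) and top idle elements back up to `block_length_`
+        // buffered results.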
+        for (int i = 0; i < dataset()->cycle_length_; ++i) {
+          int idx = (cycle_index_ + i) % dataset()->cycle_length_;
+          if (!current_elements_[idx]) {
+            if (!future_elements_.empty()) {
+              current_elements_[idx] = std::move(future_elements_.back());
+              future_elements_.pop_back();
+            } else {
+              current_elements_[idx] = MakeElement(ctx);
+              if (!current_elements_[idx]) {
+                continue;
+              }
+            }
+          }
+          std::shared_ptr<Element> element = current_elements_[idx];
+          if (!element->in_use && element->iterator) {
+            int64 num_results;
+            {
+              mutex_lock l(element->mu);
+              num_results =
+                  dataset()->block_length_ - element->results.size();
+            }
+            if (num_results > 0) {
+              num_calls_++;
+              element->in_use = true;
+              thread_pool_->Schedule(
+                  std::bind(&ParallelInterleaveIterator::FetchResults, this,
+                            ctx, std::move(element), num_results));
+            }
+          }
+        }
         const auto& stats_aggregator = ctx->stats_aggregator();
         if (stats_aggregator) {
@@ -525,56 +507,178 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
       }
     }
 
-    // Determines whether the caller needs to wait for a result. Upon
-    // returning false, `result` will either be NULL if end of input has been
-    // reached or point to the result.
-    bool ShouldWait(std::shared_ptr<InvocationResult>* result)
+    void EnsureThreadsStarted(IteratorContext* ctx)
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-      if (sloppy_) {
-        for (auto it = invocation_results_.begin();
-             it != invocation_results_.end(); ++it) {
-          if ((*it)->notification.HasBeenNotified()) {
-            std::swap(*result, *it);
-            invocation_results_.erase(it);
-            cond_var_->notify_all();
-            return false;
-          }
-        }
-        return !invocation_results_.empty() ||
-               (!end_of_input_ || num_open_ > 0);
-      } else {
-        if (!invocation_results_.empty()) {
-          std::swap(*result, invocation_results_.front());
-          invocation_results_.pop_front();
-          cond_var_->notify_all();
-          return false;
-        }
-        return (!end_of_input_ || num_open_ > 0);
+      if (!current_elements_manager_) {
+        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+        current_elements_manager_ = WrapUnique(ctx->env()->StartThread(
+            {}, "tf_data_parallel_interleave_current",
+            [this, new_ctx]() { CurrentElementsManager(new_ctx); }));
+      }
+      if (!future_elements_manager_) {
+        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+        future_elements_manager_ = WrapUnique(ctx->env()->StartThread(
+            {}, "tf_data_parallel_interleave_future",
+            [this, new_ctx]() { FutureElementsManager(new_ctx); }));
       }
     }
 
+    // Fetches up to `num_results` results from `element`.
+    void FetchResults(const std::shared_ptr<IteratorContext>& ctx,
+                      const std::shared_ptr<Element>& element,
+                      int64 num_results) LOCKS_EXCLUDED(*mu_) {
+      RecordStart(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+      bool end_of_input = false;
+      for (int64 i = 0; i < num_results; ++i) {
+        auto result = std::make_shared<Result>();
+        result->status = element->iterator->GetNext(
+            ctx.get(), &result->return_values, &end_of_input);
+        if (end_of_input) {
+          break;
+        }
+        RecordBufferEnqueue(ctx.get(), result->return_values);
+        mutex_lock l(*mu_);
+        mutex_lock l2(element->mu);
+        element->results.push_back(result);
+        result->is_ready = true;
+        cond_var_->notify_all();
+      }
+
+      mutex_lock l(*mu_);
+      // Release the ownership of the cycle element iterator.
+      element->in_use = false;
+      if (end_of_input) {
+        // Close the iterator if end of input was encountered.
+        element->iterator.reset();
+        element->inputs.clear();
+        --num_open_;
+      }
+      --num_calls_;
+      const auto& stats_aggregator = ctx->stats_aggregator();
+      if (stats_aggregator) {
+        stats_aggregator->AddScalar(
+            strings::StrCat(key_prefix_, "::thread_utilization"),
+            static_cast<float>(num_calls_) /
+                static_cast<float>(num_parallel_calls_->value));
+      }
+      cond_var_->notify_all();
+    }
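+
+    // Note: a failed `GetNext` call above still publishes its error status
+    // as a ready result, so the consumer observes errors in fetch order
+    // rather than losing them.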
+
+    // Manages future cycle elements, creating new iterators as needed and
+    // asynchronously fetching results from existing iterators.
+    //
+    // This method runs in the `future_elements_manager_` background thread.
+    void FutureElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
+      RecordStart(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+        return num_calls_ >= num_parallel_calls_->value ||
+               future_elements_.size() >= dataset()->cycle_length_;
+      };
+      while (true) {
+        mutex_lock l(*mu_);
+
+        // Wait until this thread is cancelled, the end of input has been
+        // reached, or there is capacity to prefetch another future element.
+        while (!cancelled_ && !end_of_input_ && busy()) {
+          RecordStop(ctx.get());
+          cond_var_->wait(l);
+          RecordStart(ctx.get());
+        }
+
+        if (cancelled_ || end_of_input_) {
+          return;
+        }
+
+        while (!end_of_input_ && !busy()) {
+          std::shared_ptr<Element> element = MakeElement(ctx);
+          if (!element) {
+            break;
+          }
+          future_elements_.push_front(element);
+          if (!element->iterator) {
+            continue;
+          }
+          ++num_calls_;
+          element->in_use = true;
+          thread_pool_->Schedule(
+              std::bind(&ParallelInterleaveIterator::FetchResults, this, ctx,
+                        std::move(element), dataset()->block_length_));
+        }
+        const auto& stats_aggregator = ctx->stats_aggregator();
+        if (stats_aggregator) {
+          stats_aggregator->AddScalar(
+              strings::StrCat(key_prefix_, "::thread_utilization"),
+              static_cast<float>(num_calls_) /
+                  static_cast<float>(num_parallel_calls_->value));
+        }
+        cond_var_->notify_all();
+      }
+    }
+
+    // Creates a new element.
+    std::shared_ptr<Element> MakeElement(
+        const std::shared_ptr<IteratorContext>& ctx)
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      auto element = std::make_shared<Element>();
+      element->id = element_id_counter_++;
+      Status status =
+          input_impl_->GetNext(ctx.get(), &element->inputs, &end_of_input_);
+      if (!status.ok()) {
+        auto result = std::make_shared<Result>();
+        result->is_ready = true;
+        result->status = status;
+        mutex_lock l(element->mu);
+        element->results.push_back(std::move(result));
+        return element;
+      }
+      if (!end_of_input_) {
+        Status status = MakeIteratorFromInputElement(
+            ctx.get(), element->inputs, element->id,
+            *instantiated_captured_func_, prefix(), &element->iterator);
+        if (!status.ok()) {
+          auto result = std::make_shared<Result>();
+          result->is_ready = true;
+          result->status = status;
+          mutex_lock l(element->mu);
+          element->results.push_back(std::move(result));
+          return element;
+        }
+        ++num_open_;
+      } else {
+        element.reset();
+      }
+      return element;
+    }
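+
+    // `MakeElement` returns nullptr only when the end of input is reached
+    // cleanly; input and iterator-creation errors are surfaced to the
+    // consumer as a ready result carrying the error status.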
strings::StrCat("invocation_results[", index, "].error_message")); + strings::StrCat(key_prefix, ".results[", idx, "].error_message")); + } + + Status WriteElement(std::shared_ptr element, int idx, + const string& key_prefix, IteratorStateWriter* writer) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (element->iterator) { + TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(key_prefix, "[", idx, "].id")), + element->id)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")), + element->inputs.size())); + for (int i = 0; i < element->inputs.size(); i++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")), + element->inputs[i])); + } + } + mutex_lock l(element->mu); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")), + element->results.size())); + for (size_t i = 0; i < element->results.size(); i++) { + std::shared_ptr result = element->results[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked( + writer, strings::StrCat(key_prefix, "[", idx, "]"), i, + result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i, + "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i, + "][", j, "]")), + result->return_values[j])); + } + if (result->is_ready) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i, + "].is_ready")), + "")); + } + } + return Status::OK(); } Status WriteCurrentElements(IteratorStateWriter* writer) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("current_elements.size"), current_elements_.size())); for (int idx = 0; idx < current_elements_.size(); idx++) { - if (current_elements_[idx].iterator) { - TF_RETURN_IF_ERROR( - SaveInput(writer, current_elements_[idx].iterator)); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name( - strings::StrCat("current_elements[", idx, "].inputs.size")), - current_elements_[idx].inputs.size())); - for (int i = 0; i < current_elements_[idx].inputs.size(); i++) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat("current_elements[", idx, - "].inputs[", i, "]")), - current_elements_[idx].inputs[i])); - } + if (current_elements_[idx]) { + TF_RETURN_IF_ERROR(WriteElement(current_elements_[idx], idx, + "current_elements", writer)); } } return Status::OK(); } + Status WriteFutureElements(IteratorStateWriter* writer) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("future_elements.size"), future_elements_.size())); + for (int idx = 0; idx < future_elements_.size(); idx++) { + if (future_elements_[idx]) { + TF_RETURN_IF_ERROR(WriteElement(future_elements_[idx], idx, + "future_elements", writer)); + } + } + return Status::OK(); + } + + Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader, + int idx, const string& key_prefix, + std::shared_ptr* out) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (!reader->Contains(full_name( + strings::StrCat(key_prefix, "[", idx, "].results.size")))) { + return Status::OK(); + } + auto element = std::make_shared(); + mutex_lock l(element->mu); + int64 results_size; + 
+
+    Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
+                       int idx, const string& key_prefix,
+                       std::shared_ptr<Element>* out)
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      if (!reader->Contains(full_name(
+              strings::StrCat(key_prefix, "[", idx, "].results.size")))) {
+        return Status::OK();
+      }
+      auto element = std::make_shared<Element>();
+      mutex_lock l(element->mu);
+      int64 results_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")),
+          &results_size));
+      element->results.resize(results_size);
+      for (size_t i = 0; i < results_size; i++) {
+        auto result = std::make_shared<Result>();
+        TF_RETURN_IF_ERROR(ReadStatusLocked(
+            reader, strings::StrCat(key_prefix, "[", idx, "]"), i,
+            &result->status));
+        int64 num_return_values;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(
+            full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
+                                      "].size")),
+            &num_return_values));
+        result->return_values.reserve(num_return_values);
+        for (size_t j = 0; j < num_return_values; j++) {
+          result->return_values.emplace_back();
+          TF_RETURN_IF_ERROR(reader->ReadTensor(
+              full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
+                                        "][", j, "]")),
+              &result->return_values.back()));
+        }
+        result->is_ready = reader->Contains(full_name(strings::StrCat(
+            key_prefix, "[", idx, "].results[", i, "].is_ready")));
+        element->results[i] = std::move(result);
+      }
+      if (!reader->Contains(full_name(
+              strings::StrCat(key_prefix, "[", idx, "].inputs.size")))) {
+        element->iterator.reset();
+        *out = std::move(element);
+        return Status::OK();
+      }
+      int64 inputs_size;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")),
+          &inputs_size));
+      element->inputs.resize(inputs_size);
+      for (int i = 0; i < inputs_size; i++) {
+        TF_RETURN_IF_ERROR(reader->ReadTensor(
+            full_name(
+                strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")),
+            &element->inputs[i]));
+      }
+      TF_RETURN_IF_ERROR(reader->ReadScalar(
+          full_name(strings::StrCat(key_prefix, "[", idx, "].id")),
+          &element->id));
+      TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+          ctx, element->inputs, element->id,
+          *instantiated_captured_func_.get(), prefix(), &element->iterator));
+      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, element->iterator));
+      *out = std::move(element);
+      return Status::OK();
+    }
+
     Status ReadCurrentElements(IteratorContext* ctx,
                                IteratorStateReader* reader)
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      int64 size;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name("current_elements.size"), &size));
+      DCHECK_EQ(current_elements_.size(), size);
       for (int idx = 0; idx < current_elements_.size(); idx++) {
-        if (reader->Contains(full_name(strings::StrCat(
-                "current_elements[", idx, "].inputs.size")))) {
-          int64 inputs_size;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(
-              full_name(
-                  strings::StrCat("current_elements[", idx, "].inputs.size")),
-              &inputs_size));
-          current_elements_[idx].inputs.resize(inputs_size);
-          for (int i = 0; i < inputs_size; i++) {
-            TF_RETURN_IF_ERROR(reader->ReadTensor(
-                full_name(strings::StrCat("current_elements[", idx,
-                                          "].inputs[", i, "]")),
-                &current_elements_[idx].inputs[i]));
-          }
-          TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
-              ctx, current_elements_[idx].inputs, idx,
-              *instantiated_captured_func_.get(), prefix(),
-              &current_elements_[idx].iterator));
-          TF_RETURN_IF_ERROR(
-              RestoreInput(ctx, reader, current_elements_[idx].iterator));
-        } else {
-          current_elements_[idx].iterator.reset();
-        }
+        TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "current_elements",
+                                       &current_elements_[idx]));
       }
       return Status::OK();
     }
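+
+    // Elements restored by `ReadElement` come back with `in_use == false`,
+    // so the manager threads resume scheduling fetches for them once the
+    // restored iterators are in place.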
+
+    Status ReadFutureElements(IteratorContext* ctx,
+                              IteratorStateReader* reader)
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+      int64 size;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name("future_elements.size"), &size));
+      future_elements_.resize(size);
+      for (int idx = 0; idx < future_elements_.size(); idx++) {
+        TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "future_elements",
+                                       &future_elements_[idx]));
+      }
+      return Status::OK();
+    }
 
@@ -648,12 +866,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     // the worker threads.
     const std::shared_ptr<mutex> mu_;
 
-    // Used for coordination between the main thread, the runner thread, and
-    // the worker threads. In particular, the runner thread should only
-    // schedule new calls when the number of in-flight calls is less than the
-    // user specified level of parallelism, there are slots available in the
-    // `invocation_results_` buffer, the current cycle element is not in use,
-    // and there are elements left to be fetched.
+    // Used for coordination between the main thread, the manager threads, and
+    // the threadpool threads. In particular, the manager threads should only
+    // schedule new calls into the threadpool when the number of in-flight
+    // calls is less than the user-specified level of parallelism and there
+    // are slots available in the element `results` buffer.
     const std::shared_ptr<condition_variable> cond_var_;
 
     // Identifies the maximum number of parallel calls.
@@ -665,18 +882,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     // Iterator for input elements.
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
 
-    // Identifies current cycle element.
-    int64 cycle_index_ = 0;
+    // Identifies position in the interleave cycle.
+    int64 block_index_ GUARDED_BY(*mu_) = 0;
+    int64 cycle_index_ GUARDED_BY(*mu_) = 0;
 
-    // Iterators for the current cycle elements. Concurrent access is
-    // protected by `element_in_use_`.
-    std::vector<Element> current_elements_ GUARDED_BY(*mu_);
+    // Elements of the current interleave cycle.
+    std::vector<std::shared_ptr<Element>> current_elements_ GUARDED_BY(*mu_);
 
-    // Buffer for storing the invocation results.
-    std::deque<std::shared_ptr<InvocationResult>> invocation_results_
-        GUARDED_BY(*mu_);
+    // Elements to be used in the interleave cycle in the future.
+    std::deque<std::shared_ptr<Element>> future_elements_ GUARDED_BY(*mu_);
 
-    // Identifies whether end of input has been reached.
+    // Identifies whether the global end of input has been reached.
     bool end_of_input_ GUARDED_BY(*mu_) = false;
 
     // Identifies the number of open iterators.
@@ -686,9 +902,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     int64 num_calls_ GUARDED_BY(*mu_) = 0;
 
     std::unique_ptr<thread::ThreadPool> thread_pool_;
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<Thread> current_elements_manager_ GUARDED_BY(*mu_);
+    std::unique_ptr<Thread> future_elements_manager_ GUARDED_BY(*mu_);
+    int64 element_id_counter_ GUARDED_BY(*mu_) = 0;
 
-    // Identifies whether background activity should be cancelled.
+    // Identifies whether background threads should be cancelled.
     bool cancelled_ GUARDED_BY(*mu_) = false;
     string key_prefix_;
     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
diff --git a/tensorflow/python/data/kernel_tests/interleave_test.py b/tensorflow/python/data/kernel_tests/interleave_test.py
index 4fb61b2daf1..4b427ff5a41 100644
--- a/tensorflow/python/data/kernel_tests/interleave_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_test.py
@@ -17,19 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import threading
-
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python.data.experimental.ops import threading_options
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import test
 
@@ -78,47 +74,6 @@ def _interleave(lists, cycle_length, block_length):
       break
 
 
-def _make_coordinated_sloppy_dataset(input_values, cycle_length, block_length,
-                                     num_parallel_calls):
-  """Produces a dataset iterator and events to control the order of elements.
-
-  Args:
-    input_values: the values to generate lists to interleave from
-    cycle_length: the length of the interleave cycle
-    block_length: the length of the interleave block
-    num_parallel_calls: the degree of interleave parallelism
-
-  Returns:
-    A dataset iterator (represented as `get_next` op) and events that can be
-    used to control the order of output elements.
-  """
-
-  # Set up threading events used to sequence when items are produced that
-  # are subsequently interleaved. These events allow us to deterministically
-  # simulate slowdowns and force sloppiness.
-  coordination_events = {i: threading.Event() for i in input_values}
-
-  def map_py_fn(x):
-    coordination_events[x].wait()
-    coordination_events[x].clear()
-    return x * x
-
-  def map_fn(x):
-    return script_ops.py_func(map_py_fn, [x], x.dtype)
-
-  def interleave_fn(x):
-    dataset = dataset_ops.Dataset.from_tensors(x)
-    dataset = dataset.repeat(x)
-    return dataset.map(map_fn)
-
-  options = dataset_ops.Options()
-  options.experimental_deterministic = False
-  dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
-      2).interleave(interleave_fn, cycle_length, block_length,
-                    num_parallel_calls).with_options(options)
-  return dataset, coordination_events
-
-
 def _repeat(values, count):
   """Produces a list of lists suitable for testing interleave.
@@ -252,63 +207,37 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.evaluate(get_next())
 
   @parameterized.named_parameters(
-      ("1", np.int64([4, 5, 6]), 2, 1, 1),
-      ("2", np.int64([4, 5, 6]), 2, 1, 2),
-      ("3", np.int64([4, 5, 6]), 2, 3, 1),
-      ("4", np.int64([4, 5, 6]), 2, 3, 2),
-      ("5", np.int64([4, 5, 6]), 3, 2, 1),
-      ("6", np.int64([4, 5, 6]), 3, 2, 2),
-      ("7", np.int64([4, 5, 6]), 3, 2, 3),
-      ("8", np.int64([4, 0, 6]), 2, 3, 1),
-      ("9", np.int64([4, 0, 6]), 2, 3, 2),
+      ("1", np.int64([4, 5, 6]), 1, 3, 1),
+      ("2", np.int64([4, 5, 6]), 2, 1, 1),
+      ("3", np.int64([4, 5, 6]), 2, 1, 2),
+      ("4", np.int64([4, 5, 6]), 2, 3, 1),
+      ("5", np.int64([4, 5, 6]), 2, 3, 2),
+      ("6", np.int64([4, 5, 6]), 7, 2, 1),
+      ("7", np.int64([4, 5, 6]), 7, 2, 3),
+      ("8", np.int64([4, 5, 6]), 7, 2, 5),
+      ("9", np.int64([4, 5, 6]), 7, 2, 7),
+      ("10", np.int64([4, 0, 6]), 2, 3, 1),
+      ("11", np.int64([4, 0, 6]), 2, 3, 2),
   )
-  def testSloppyInterleaveInOrder(self, input_values, cycle_length,
+  def testSloppyInterleaveDataset(self, input_values, cycle_length,
                                   block_length, num_parallel_calls):
-    dataset, coordination_events = _make_coordinated_sloppy_dataset(
-        input_values, cycle_length, block_length, num_parallel_calls)
+    count = 2
+    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+        count).interleave(
+            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+            cycle_length, block_length, num_parallel_calls)
     options = dataset_ops.Options()
-    options.experimental_threading = threading_options.ThreadingOptions()
-    options.experimental_threading.private_threadpool_size = (
-        num_parallel_calls + 1)
+    options.experimental_deterministic = False
     dataset = dataset.with_options(options)
-
-    get_next = self.getNext(dataset, requires_initialization=True)
-    for expected_element in _interleave(
-        _repeat(input_values, 2), cycle_length, block_length):
-      coordination_events[expected_element].set()
-      self.assertEqual(expected_element * expected_element,
-                       self.evaluate(get_next()))
-    with self.assertRaises(errors.OutOfRangeError):
-      self.evaluate(get_next())
-
-  @parameterized.named_parameters(
-      ("1", np.int64([4, 5, 6]), 2, 1, 2),
-      ("2", np.int64([4, 5, 6]), 2, 3, 2),
-      ("3", np.int64([4, 5, 6]), 3, 2, 3),
-      ("4", np.int64([4, 0, 6]), 2, 3, 2),
-  )
-  def testSloppyInterleaveOutOfOrder(self, input_values, cycle_length,
-                                     block_length, num_parallel_calls):
-    dataset, coordination_events = _make_coordinated_sloppy_dataset(
-        input_values, cycle_length, block_length, num_parallel_calls)
-    options = dataset_ops.Options()
-    options.experimental_threading = threading_options.ThreadingOptions()
-    options.experimental_threading.private_threadpool_size = (
-        num_parallel_calls + 1)
-    dataset = dataset.with_options(options)
-    get_next = self.getNext(dataset, requires_initialization=True)
-    elements = [
-        x for x in _interleave(
-            _repeat(input_values, 2), cycle_length, block_length)
+    expected_output = [
+        element for element in _interleave(
+            _repeat(input_values, count), cycle_length, block_length)
     ]
-    for i in [1, 4, 7]:
-      elements[i], elements[i + 1] = elements[i + 1], elements[i]
-
-    for element in elements:
-      coordination_events[element].set()
-      self.assertEqual(element * element, self.evaluate(get_next()))
-    with self.assertRaises(errors.OutOfRangeError):
-      self.evaluate(get_next())
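+    # The transformation may reorder output when determinism is disabled,
+    # so compare the expected and actual outputs as multisets.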
+    get_next = self.getNext(dataset)
+    actual_output = []
+    for _ in range(len(expected_output)):
+      actual_output.append(self.evaluate(get_next()))
+    self.assertAllEqual(sorted(expected_output), sorted(actual_output))
 
 
 if __name__ == "__main__":