[tf.data] Adding prefetching of input iterators for parallel interleave.

PiperOrigin-RevId: 227911701
Jiri Simsa 2019-01-04 14:10:40 -08:00 committed by TensorFlower Gardener
parent efe565bc09
commit 41e333f019
2 changed files with 533 additions and 386 deletions
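For orientation, the kernel changed below backs tf.data's parallel interleave (i.e. `Dataset.interleave` with `num_parallel_calls` set). A minimal usage sketch in the style of the updated test further down; the values and names are illustrative, and disabling determinism is what exercises the "sloppy" consumption path this commit reworks:

import numpy as np
from tensorflow.python.data.ops import dataset_ops

input_values = np.int64([4, 5, 6])  # illustrative inputs, as in the test below
cycle_length, block_length, num_parallel_calls = 2, 3, 2

dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(2).interleave(
    lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
    cycle_length, block_length, num_parallel_calls)

# Allowing non-deterministic ("sloppy") ordering lets results be consumed
# out of order by the parallel interleave kernel.
options = dataset_ops.Options()
options.experimental_deterministic = False
dataset = dataset.with_options(options)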


@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <memory>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -21,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -43,12 +45,7 @@ namespace {
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators, relying on other parts of
// tf.data to provide this functionality if necessary.
//
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
// elements and prefetching of input iterators.
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
@@ -237,27 +234,20 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
do {
result.reset();
{
mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
if (!result) {
*end_of_sequence = true;
return Status::OK();
}
std::shared_ptr<Result> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
while (!Consume(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
} while (result->skip);
}
if (!result) {
*end_of_sequence = true;
return Status::OK();
}
if (result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
@@ -281,37 +271,22 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
while (num_calls_ > 0) {
cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
DCHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("invocation_results.size"), invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
result->return_values.size()));
for (size_t j = 0; j < result->return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat("invocation_results[", i, "][", j, "]")),
result->return_values[j]));
}
if (result->skip) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].skip")),
""));
}
}
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("block_index"), block_index_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("cycle_index"), cycle_index_));
if (end_of_input_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input"), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("element_id_counter"),
element_id_counter_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("num_open"), num_open_));
TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
TF_RETURN_IF_ERROR(WriteFutureElements(writer));
return Status::OK();
}
@@ -319,200 +294,207 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
std::shared_ptr<InvocationResult> result(new InvocationResult());
invocation_results_.push_back(result);
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(
strings::StrCat("invocation_results[", i, "].size")),
": ", size, " is not a valid value of type size_t."));
}
}
result->return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result->return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(
"invocation_results[", i, "][", j, "]")),
&result->return_values.back()));
}
result->skip = reader->Contains(
full_name(strings::StrCat("invocation_results[", i, "].skip")));
result->notification.Notify();
}
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("block_index"), &block_index_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("element_id_counter"),
&element_id_counter_));
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("num_open"), &num_open_));
TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
return Status::OK();
}
private:
// Represents the result of fetching an element from a dataset.
struct Result {
Status status;
std::vector<Tensor> return_values;
// Indicates whether the result is ready to be consumed.
bool is_ready = false;
};
// The interleave transformation repeatedly inputs elements, applies the
// user-provided function to transform the input elements to datasets, and
// interleaves the elements of these datasets as its output.
//
// This structure represents an input element and derived state.
struct Element {
// Unique identifier, needed to support checkpointing.
int64 id;
// The actual input element.
std::vector<Tensor> inputs;
// Iterator created from the input element.
std::unique_ptr<IteratorBase> iterator;
std::vector<Tensor> inputs; // inputs for creating the iterator
bool in_use;
mutex mu;
// Buffer for storing the outputs of `iterator`.
std::deque<std::shared_ptr<Result>> results GUARDED_BY(mu);
// Indicates whether the element is used by a worker thread.
bool in_use = false;
};
struct InvocationResult {
Notification notification; // used for coordination with the consumer
Status status; // the invocation status
std::vector<Tensor> return_values; // the invocation result values
bool skip; // if set the result should be skipped
};
// Advances the position in the interleave cycle to the next cycle
// element.
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
block_index_ = 0;
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
// Advances the position in the interleave cycle by one.
void AdvancePosition() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
++block_index_;
if (block_index_ == dataset()->block_length_) {
AdvanceToNextInCycle();
}
}
// Consumes a result (if available), returning an indication of whether
// a result is available. If `true` is returned, `result` either
// points to a valid result or is null if end of input has been reached.
bool Consume(std::shared_ptr<Result>* result)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_runner",
[this, new_ctx]() { RunnerThread(new_ctx); }));
if (!sloppy_) {
return ConsumeHelper(result);
}
// If we are allowed to be sloppy (i.e. return results out of order),
// try to find an element in the cycle that has a result available.
for (int i = 0; i < dataset()->cycle_length_; ++i) {
if (ConsumeHelper(result)) {
return true;
}
AdvanceToNextInCycle();
}
return false;
}
bool ConsumeHelper(std::shared_ptr<Result>* result)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
while (true) {
std::shared_ptr<Element> element = current_elements_[cycle_index_];
if (element) {
mutex_lock l(element->mu);
if (!element->results.empty()) {
if (element->results.front()->is_ready) {
// We found a result.
std::swap(*result, element->results.front());
element->results.pop_front();
AdvancePosition();
cond_var_->notify_all();
return true;
} else {
// Wait for the result to become ready.
return false;
}
} else if (!element->iterator) {
// We reached the end of input for this element. Reset
// it and move on to the next cycle element.
current_elements_[cycle_index_].reset();
AdvanceToNextInCycle();
cond_var_->notify_all();
continue;
} else {
// Wait for the iterator to produce a result.
return false;
}
} else {
if (!future_elements_.empty() || !end_of_input_) {
// Wait for an element to be created.
return false;
}
// No new elements will be created; try to find a
// non-empty element in the cycle.
for (int i = 0; i < dataset()->cycle_length_; ++i) {
AdvanceToNextInCycle();
if (current_elements_[cycle_index_]) {
break;
}
}
if (current_elements_[cycle_index_]) {
continue;
}
// End of input has been reached.
return true;
}
}
}
// Fetches up to `results.size()` outputs from the cycle element at
// position `cycle_index`.
// Manages current cycle elements, creating new iterators as needed and
// asynchronously fetching results from existing iterators.
//
// If end of input is encountered, the `skip` field of the invocation
// result is used to identify results that should be skipped.
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, IteratorBase* iterator,
int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
result->status = iterator->GetNext(
ctx.get(), &result->return_values, &end_of_input);
}
if (end_of_input) {
result->skip = true;
}
RecordBufferEnqueue(ctx.get(), result->return_values);
{
mutex_lock l(*mu_);
result->notification.Notify();
cond_var_->notify_all();
}
if (!result->status.ok()) {
break;
}
}
mutex_lock l(*mu_);
current_elements_[cycle_index].in_use = false;
if (end_of_input) {
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
current_elements_[cycle_index].iterator.reset();
current_elements_[cycle_index].inputs.clear();
num_open_--;
}
num_calls_--;
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
strings::StrCat(key_prefix_, "::thread_utilization"),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value));
}
cond_var_->notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
// This method runs in the `current_elements_manager_` background thread.
void CurrentElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return current_elements_[cycle_index_].in_use ||
num_calls_ >= num_parallel_calls_->value ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
const bool has_more_elements =
!future_elements_.empty() || !end_of_input_;
const int block_length = dataset()->block_length_;
bool all_elements_busy = true;
for (auto& element : current_elements_) {
if (!element) {
if (has_more_elements) {
all_elements_busy = false;
break;
}
} else {
mutex_lock l(element->mu);
if (!element->in_use && element->iterator &&
element->results.size() < block_length) {
all_elements_busy = false;
break;
}
}
}
return all_elements_busy || num_calls_ >= num_parallel_calls_->value;
};
while (true) {
mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
// reached.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
if (cancelled_ ||
(future_elements_.empty() && end_of_input_ && num_open_ == 0)) {
return;
}
while ((!end_of_input_ || num_open_ > 0) && !busy()) {
if (!current_elements_[cycle_index_].iterator) {
// Try to create a new iterator from the next input element.
Status status = input_impl_->GetNext(
ctx.get(), &current_elements_[cycle_index_].inputs,
&end_of_input_);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
invocation_results_.back();
result->status.Update(status);
result->notification.Notify();
break;
}
if (!end_of_input_) {
Status status = MakeIteratorFromInputElement(
ctx.get(), current_elements_[cycle_index_].inputs,
cycle_index_, *instantiated_captured_func_, prefix(),
&current_elements_[cycle_index_].iterator);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
invocation_results_.back();
result->status.Update(status);
result->notification.Notify();
break;
for (int i = 0; i < dataset()->cycle_length_; ++i) {
int idx = (cycle_index_ + i) % dataset()->cycle_length_;
if (!current_elements_[idx]) {
if (!future_elements_.empty()) {
current_elements_[idx] = std::move(future_elements_.back());
future_elements_.pop_back();
} else {
current_elements_[idx] = MakeElement(ctx);
if (!current_elements_[idx]) {
continue;
}
++num_open_;
}
}
if (current_elements_[cycle_index_].iterator) {
// Pre-allocate invocation results for outputs to be fetched
// and then fetch the outputs asynchronously.
std::vector<std::shared_ptr<InvocationResult>> results;
results.reserve(dataset()->block_length_);
for (int i = 0; i < dataset()->block_length_; ++i) {
invocation_results_.emplace_back(new InvocationResult());
results.push_back(invocation_results_.back());
std::shared_ptr<Element> element = current_elements_[idx];
if (!element->in_use && element->iterator) {
int64 num_results;
{
mutex_lock l(element->mu);
num_results =
dataset()->block_length_ - element->results.size();
}
if (num_results > 0) {
num_calls_++;
element->in_use = true;
thread_pool_->Schedule(
std::bind(&ParallelInterleaveIterator::FetchResults, this,
ctx, std::move(element), num_results));
}
num_calls_++;
current_elements_[cycle_index_].in_use = true;
thread_pool_->Schedule(
std::bind(&ParallelInterleaveIterator::FetchOutputs, this,
ctx, current_elements_[cycle_index_].iterator.get(),
cycle_index_, std::move(results)));
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
@@ -525,56 +507,178 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
}
// Determines whether the caller needs to wait for a result. Upon
// returning false, `result` will either be NULL if end of input has been
// reached or point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
void EnsureThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (sloppy_) {
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified()) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
return !invocation_results_.empty() ||
(!end_of_input_ || num_open_ > 0);
} else {
if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return (!end_of_input_ || num_open_ > 0);
if (!current_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
current_elements_manager_ =
WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_current",
[this, new_ctx]() { CurrentElementsManager(new_ctx); }));
}
if (!future_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
future_elements_manager_ = WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_future",
[this, new_ctx]() { FutureElementsManager(new_ctx); }));
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
// Fetches up to `num_results` results from `element`.
void FetchResults(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<Element>& element,
int64 num_results) LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (int64 i = 0; i < num_results; ++i) {
auto result = std::make_shared<Result>();
result->status = element->iterator->GetNext(
ctx.get(), &result->return_values, &end_of_input);
if (end_of_input) {
break;
}
RecordBufferEnqueue(ctx.get(), result->return_values);
mutex_lock l(*mu_);
mutex_lock l2(element->mu);
element->results.push_back(result);
result->is_ready = true;
cond_var_->notify_all();
}
mutex_lock l(*mu_);
// Release the ownership of the cycle element iterator.
element->in_use = false;
if (end_of_input) {
// Close the iterator if end of input was encountered.
element->iterator.reset();
element->inputs.clear();
--num_open_;
}
--num_calls_;
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
strings::StrCat(key_prefix_, "::thread_utilization"),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value));
}
cond_var_->notify_all();
}
// Manages future cycle elements, creating new iterators as needed and
// asynchronously fetching results from existing iterators.
//
// This method runs in the `future_elements_manager_` background thread.
void FutureElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return num_calls_ >= num_parallel_calls_->value ||
future_elements_.size() >= dataset()->cycle_length_;
};
while (true) {
mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or there is room to prefetch additional future elements.
while (!cancelled_ && !end_of_input_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_ || end_of_input_) {
return;
}
while (!end_of_input_ && !busy()) {
std::shared_ptr<Element> element = MakeElement(ctx);
if (!element) {
break;
}
future_elements_.push_front(element);
if (!element->iterator) {
continue;
}
++num_calls_;
element->in_use = true;
thread_pool_->Schedule(
std::bind(&ParallelInterleaveIterator::FetchResults, this, ctx,
std::move(element), dataset()->block_length_));
}
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
strings::StrCat(key_prefix_, "::thread_utilization"),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value));
}
cond_var_->notify_all();
}
}
// Creates a new element.
std::shared_ptr<Element> MakeElement(
const std::shared_ptr<IteratorContext>& ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
auto element = std::make_shared<Element>();
element->id = element_id_counter_++;
Status status =
input_impl_->GetNext(ctx.get(), &element->inputs, &end_of_input_);
if (!status.ok()) {
auto result = std::make_shared<Result>();
result->is_ready = true;
result->status = status;
mutex_lock l(element->mu);
element->results.push_back(std::move(result));
return element;
}
if (!end_of_input_) {
Status status = MakeIteratorFromInputElement(
ctx.get(), element->inputs, element->id,
*instantiated_captured_func_, prefix(), &element->iterator);
if (!status.ok()) {
auto result = std::make_shared<Result>();
result->is_ready = true;
result->status = status;
mutex_lock l(element->mu);
element->results.push_back(std::move(result));
return element;
}
++num_open_;
} else {
element.reset();
}
return element;
}
Status WriteStatusLocked(IteratorStateWriter* writer,
const string& key_prefix, size_t idx,
const Status& status)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
CodeKey(key_prefix, idx), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
TF_RETURN_IF_ERROR(writer->WriteScalar(
ErrorMessageKey(key_prefix, idx), status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status ReadStatusLocked(IteratorStateReader* reader,
const string& key_prefix, size_t idx,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
TF_RETURN_IF_ERROR(
reader->ReadScalar(CodeKey(key_prefix, idx), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
string error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
TF_RETURN_IF_ERROR(reader->ReadScalar(
ErrorMessageKey(key_prefix, idx), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
@@ -582,64 +686,178 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
string CodeKey(size_t index) {
string CodeKey(const string& key_prefix, size_t idx) {
return full_name(
strings::StrCat("invocation_results[", index, "].code"));
strings::StrCat(key_prefix, ".results[", idx, "].code"));
}
string ErrorMessageKey(size_t index) {
string ErrorMessageKey(const string& key_prefix, size_t idx) {
return full_name(
strings::StrCat("invocation_results[", index, "].error_message"));
strings::StrCat(key_prefix, ".results[", idx, "].error_message"));
}
Status WriteElement(std::shared_ptr<Element> element, int idx,
const string& key_prefix, IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (element->iterator) {
TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].id")),
element->id));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")),
element->inputs.size()));
for (int i = 0; i < element->inputs.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")),
element->inputs[i]));
}
}
mutex_lock l(element->mu);
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")),
element->results.size()));
for (size_t i = 0; i < element->results.size(); i++) {
std::shared_ptr<Result> result = element->results[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(
writer, strings::StrCat(key_prefix, "[", idx, "]"), i,
result->status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
"].size")),
result->return_values.size()));
for (size_t j = 0; j < result->return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
"][", j, "]")),
result->return_values[j]));
}
if (result->is_ready) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
"].is_ready")),
""));
}
}
return Status::OK();
}
Status WriteCurrentElements(IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("current_elements.size"), current_elements_.size()));
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx].iterator) {
TF_RETURN_IF_ERROR(
SaveInput(writer, current_elements_[idx].iterator));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("current_elements[", idx, "].inputs.size")),
current_elements_[idx].inputs.size()));
for (int i = 0; i < current_elements_[idx].inputs.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("current_elements[", idx,
"].inputs[", i, "]")),
current_elements_[idx].inputs[i]));
}
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(WriteElement(current_elements_[idx], idx,
"current_elements", writer));
}
}
return Status::OK();
}
Status WriteFutureElements(IteratorStateWriter* writer)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("future_elements.size"), future_elements_.size()));
for (int idx = 0; idx < future_elements_.size(); idx++) {
if (future_elements_[idx]) {
TF_RETURN_IF_ERROR(WriteElement(future_elements_[idx], idx,
"future_elements", writer));
}
}
return Status::OK();
}
Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
int idx, const string& key_prefix,
std::shared_ptr<Element>* out)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!reader->Contains(full_name(
strings::StrCat(key_prefix, "[", idx, "].results.size")))) {
return Status::OK();
}
auto element = std::make_shared<Element>();
mutex_lock l(element->mu);
int64 results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")),
&results_size));
element->results.resize(results_size);
for (size_t i = 0; i < results_size; i++) {
auto result = std::make_shared<Result>();
TF_RETURN_IF_ERROR(ReadStatusLocked(
reader, strings::StrCat(key_prefix, "[", idx, "]"), i,
&result->status));
int64 num_return_values;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
"].size")),
&num_return_values));
result->return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result->return_values.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
"][", j, "]")),
&result->return_values.back()));
}
result->is_ready = reader->Contains(full_name(strings::StrCat(
key_prefix, "[", idx, "].results[", i, "].is_ready")));
element->results[i] = std::move(result);
}
if (!reader->Contains(full_name(
strings::StrCat(key_prefix, "[", idx, "].inputs.size")))) {
element->iterator.reset();
*out = std::move(element);
return Status::OK();
}
int64 inputs_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")),
&inputs_size));
element->inputs.resize(inputs_size);
for (int i = 0; i < inputs_size; i++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(
strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")),
&element->inputs[i]));
}
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(key_prefix, "[", idx, "].id")),
&element->id));
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, element->inputs, element->id,
*instantiated_captured_func_.get(), prefix(), &element->iterator));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, element->iterator));
*out = std::move(element);
return Status::OK();
}
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("current_elements.size"), &size));
DCHECK_EQ(current_elements_.size(), size);
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(full_name(strings::StrCat(
"current_elements[", idx, "].inputs.size")))) {
int64 inputs_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(
strings::StrCat("current_elements[", idx, "].inputs.size")),
&inputs_size));
current_elements_[idx].inputs.resize(inputs_size);
for (int i = 0; i < inputs_size; i++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("current_elements[", idx,
"].inputs[", i, "]")),
&current_elements_[idx].inputs[i]));
}
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, current_elements_[idx].inputs, idx,
*instantiated_captured_func_.get(), prefix(),
&current_elements_[idx].iterator));
TF_RETURN_IF_ERROR(
RestoreInput(ctx, reader, current_elements_[idx].iterator));
} else {
current_elements_[idx].iterator.reset();
}
TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "current_elements",
&current_elements_[idx]));
}
return Status::OK();
}
Status ReadFutureElements(IteratorContext* ctx,
IteratorStateReader* reader)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("future_elements.size"), &size));
future_elements_.resize(size);
for (int idx = 0; idx < future_elements_.size(); idx++) {
TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "future_elements",
&future_elements_[idx]));
}
return Status::OK();
}
@@ -648,12 +866,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// the worker threads.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
// schedule new calls when the number of in-flight calls is less than the
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
// Used for coordination between the main thread, the manager threads, and
// the threadpool threads. In particular, the manager threads should only
// schedule new calls into the threadpool when the number of in-flight
// calls is less than the user-specified level of parallelism and there
// are slots available in the element `results` buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
@@ -665,18 +882,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Identifies position in the interleave cycle.
int64 block_index_ GUARDED_BY(*mu_) = 0;
int64 cycle_index_ GUARDED_BY(*mu_) = 0;
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<Element> current_elements_ GUARDED_BY(*mu_);
// Elements of the current interleave cycle.
std::vector<std::shared_ptr<Element>> current_elements_ GUARDED_BY(*mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
// Elements to be used in the interleave cycle in the future.
std::deque<std::shared_ptr<Element>> future_elements_ GUARDED_BY(*mu_);
// Identifies whether end of input has been reached.
// Identifies whether the global end of input has been reached.
bool end_of_input_ GUARDED_BY(*mu_) = false;
// Identifies the number of open iterators.
@@ -686,9 +902,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
std::unique_ptr<Thread> current_elements_manager_ GUARDED_BY(*mu_);
std::unique_ptr<Thread> future_elements_manager_ GUARDED_BY(*mu_);
int64 element_id_counter_ GUARDED_BY(*mu_) = 0;
// Identifies whether background activity should be cancelled.
// Identifies whether background threads should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false;
string key_prefix_;
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

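Stepping back from the kernel diff above: the consumer side (GetNextInternal/Consume) hands results out one at a time and walks the interleave cycle `block_length_` results per element via AdvancePosition()/AdvanceToNextInCycle(). A minimal Python analogue of that position arithmetic, with hypothetical names chosen for illustration only:

def advance_position(cycle_index, block_index, cycle_length, block_length):
  # Mirrors AdvancePosition()/AdvanceToNextInCycle() above: after
  # `block_length` results from the current cycle element, move to the
  # next element in the cycle.
  block_index += 1
  if block_index == block_length:
    block_index = 0
    cycle_index = (cycle_index + 1) % cycle_length
  return cycle_index, block_index

positions = []
cycle_index, block_index = 0, 0
for _ in range(7):
  positions.append((cycle_index, block_index))
  cycle_index, block_index = advance_position(cycle_index, block_index,
                                              cycle_length=2, block_length=3)
# positions == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (0, 0)]

The Python test changes below adjust the sloppy-interleave tests to the reworked kernel.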

@@ -17,19 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@@ -78,47 +74,6 @@ def _interleave(lists, cycle_length, block_length):
break
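# Illustrative sketch, not part of this diff: a plausible restatement of the
# `_interleave` reference helper named in the hunk header above, which these
# tests use to compute expected outputs. It keeps a cycle of `cycle_length`
# open inputs, takes up to `block_length` items from each before moving on,
# and opens the next input when one is exhausted; edge cases may differ from
# the real helper.
def _reference_interleave(lists, cycle_length, block_length):
  pending = iter(lists)
  cycle = [None] * cycle_length

  def refill(slot):
    try:
      cycle[slot] = iter(next(pending))
    except StopIteration:
      cycle[slot] = None  # no inputs left to open

  for slot in range(cycle_length):
    refill(slot)
  output = []
  while any(it is not None for it in cycle):
    for slot in range(cycle_length):
      if cycle[slot] is None:
        continue
      for _ in range(block_length):
        try:
          output.append(next(cycle[slot]))
        except StopIteration:
          refill(slot)
          break
  return output

# Example with made-up inputs: _reference_interleave([[1, 1, 1], [2, 2], [3]],
# cycle_length=2, block_length=1) == [1, 2, 1, 2, 1, 3]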
def _make_coordinated_sloppy_dataset(input_values, cycle_length, block_length,
num_parallel_calls):
"""Produces a dataset iterator and events to control the order of elements.
Args:
input_values: the values to generate lists to interleave from
cycle_length: the length of the interleave cycle
block_length: the length of the interleave block
num_parallel_calls: the degree of interleave parallelism
Returns:
A dataset iterator (represented as `get_next` op) and events that can be
used to control the order of output elements.
"""
# Set up threading events used to sequence when items are produced that
# are subsequently interleaved. These events allow us to deterministically
# simulate slowdowns and force sloppiness.
coordination_events = {i: threading.Event() for i in input_values}
def map_py_fn(x):
coordination_events[x].wait()
coordination_events[x].clear()
return x * x
def map_fn(x):
return script_ops.py_func(map_py_fn, [x], x.dtype)
def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x)
dataset = dataset.repeat(x)
return dataset.map(map_fn)
options = dataset_ops.Options()
options.experimental_deterministic = False
dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
2).interleave(interleave_fn, cycle_length, block_length,
num_parallel_calls).with_options(options)
return dataset, coordination_events
def _repeat(values, count):
"""Produces a list of lists suitable for testing interleave.
@@ -252,63 +207,37 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(get_next())
@parameterized.named_parameters(
("1", np.int64([4, 5, 6]), 2, 1, 1),
("2", np.int64([4, 5, 6]), 2, 1, 2),
("3", np.int64([4, 5, 6]), 2, 3, 1),
("4", np.int64([4, 5, 6]), 2, 3, 2),
("5", np.int64([4, 5, 6]), 3, 2, 1),
("6", np.int64([4, 5, 6]), 3, 2, 2),
("7", np.int64([4, 5, 6]), 3, 2, 3),
("8", np.int64([4, 0, 6]), 2, 3, 1),
("9", np.int64([4, 0, 6]), 2, 3, 2),
("1", np.int64([4, 5, 6]), 1, 3, 1),
("2", np.int64([4, 5, 6]), 2, 1, 1),
("3", np.int64([4, 5, 6]), 2, 1, 2),
("4", np.int64([4, 5, 6]), 2, 3, 1),
("5", np.int64([4, 5, 6]), 2, 3, 2),
("6", np.int64([4, 5, 6]), 7, 2, 1),
("7", np.int64([4, 5, 6]), 7, 2, 3),
("8", np.int64([4, 5, 6]), 7, 2, 5),
("9", np.int64([4, 5, 6]), 7, 2, 7),
("10", np.int64([4, 0, 6]), 2, 3, 1),
("11", np.int64([4, 0, 6]), 2, 3, 2),
)
def testSloppyInterleaveInOrder(self, input_values, cycle_length,
def testSloppyInterleaveDataset(self, input_values, cycle_length,
block_length, num_parallel_calls):
dataset, coordination_events = _make_coordinated_sloppy_dataset(
input_values, cycle_length, block_length, num_parallel_calls)
count = 2
dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
cycle_length, block_length, num_parallel_calls)
options = dataset_ops.Options()
options.experimental_threading = threading_options.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
num_parallel_calls + 1)
options.experimental_deterministic = False
dataset = dataset.with_options(options)
get_next = self.getNext(dataset, requires_initialization=True)
for expected_element in _interleave(
_repeat(input_values, 2), cycle_length, block_length):
coordination_events[expected_element].set()
self.assertEqual(expected_element * expected_element,
self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@parameterized.named_parameters(
("1", np.int64([4, 5, 6]), 2, 1, 2),
("2", np.int64([4, 5, 6]), 2, 3, 2),
("3", np.int64([4, 5, 6]), 3, 2, 3),
("4", np.int64([4, 0, 6]), 2, 3, 2),
)
def testSloppyInterleaveOutOfOrder(self, input_values, cycle_length,
block_length, num_parallel_calls):
dataset, coordination_events = _make_coordinated_sloppy_dataset(
input_values, cycle_length, block_length, num_parallel_calls)
options = dataset_ops.Options()
options.experimental_threading = threading_options.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
num_parallel_calls + 1)
dataset = dataset.with_options(options)
get_next = self.getNext(dataset, requires_initialization=True)
elements = [
x for x in _interleave(
_repeat(input_values, 2), cycle_length, block_length)
expected_output = [
element for element in _interleave(
_repeat(input_values, count), cycle_length, block_length)
]
for i in [1, 4, 7]:
elements[i], elements[i + 1] = elements[i + 1], elements[i]
for element in elements:
coordination_events[element].set()
self.assertEqual(element * element, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
get_next = self.getNext(dataset)
actual_output = []
for _ in range(len(expected_output)):
actual_output.append(self.evaluate(get_next()))
self.assertAllEqual(sorted(expected_output), sorted(actual_output))
if __name__ == "__main__":