[tf.data] Adding prefetching of input iterators for parallel interleave.
PiperOrigin-RevId: 227911701
parent efe565bc09
commit 41e333f019
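For orientation: `interleave` maintains a cycle of `cycle_length` input elements and takes up to `block_length` outputs from each element per round-robin visit, refilling freed cycle slots from the remaining inputs. The standalone sketch below computes that output order; it mirrors the `_interleave` reference helper in the Python test at the bottom of this diff, not the C++ op itself, and all names are illustrative.

// --- standalone sketch, not part of this commit ---
#include <deque>
#include <iostream>
#include <vector>

// Emits the values of `inputs` in interleave order: up to block_length
// values per open cycle element per visit, over cycle_length slots.
std::vector<int> InterleaveOrder(std::deque<std::deque<int>> inputs,
                                 int cycle_length, int block_length) {
  std::vector<int> output;
  std::vector<std::deque<int>> cycle(cycle_length);
  std::vector<bool> open(cycle_length, false);
  int num_open = 0;
  while (num_open > 0 || !inputs.empty()) {
    for (int i = 0; i < cycle_length; ++i) {
      if (!open[i]) {  // refill a free slot from the pending inputs
        if (inputs.empty()) continue;
        cycle[i] = std::move(inputs.front());
        inputs.pop_front();
        open[i] = true;
        ++num_open;
      }
      for (int j = 0; j < block_length; ++j) {
        if (cycle[i].empty()) {  // slot exhausted: close it
          open[i] = false;
          --num_open;
          break;
        }
        output.push_back(cycle[i].front());
        cycle[i].pop_front();
      }
    }
  }
  return output;
}

int main() {
  // The test's shape of input: each x in [4, 5, 6], repeated twice, expands
  // to a dataset of x copies of x.
  std::deque<std::deque<int>> inputs;
  for (int r = 0; r < 2; ++r) {
    for (int x : {4, 5, 6}) inputs.emplace_back(x, x);
  }
  for (int v : InterleaveOrder(inputs, /*cycle_length=*/2, /*block_length=*/3)) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
}
// --- end of sketch ---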
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
@@ -21,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -43,12 +45,7 @@ namespace {
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators, relying on other parts of
// tf.data to provide this functionality if necessary.
//
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
// elements and prefetching of input iterators.
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
@@ -237,27 +234,20 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<InvocationResult> result;
      do {
        result.reset();
        {
          mutex_lock l(*mu_);
          EnsureRunnerThreadStarted(ctx);
          while (ShouldWait(&result)) {
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
          if (!result) {
            *end_of_sequence = true;
            return Status::OK();
          }
      std::shared_ptr<Result> result;
      {
        mutex_lock l(*mu_);
        EnsureThreadsStarted(ctx);
        while (!Consume(&result)) {
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
        }
        RecordStop(ctx);
        result->notification.WaitForNotification();
        RecordStart(ctx);
      } while (result->skip);

      }
      if (!result) {
        *end_of_sequence = true;
        return Status::OK();
      }
      if (result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
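// --- standalone sketch, not part of this commit ---
// The new GetNextInternal() is a classic monitor: take the shared lock, wait
// on the shared condition variable until Consume() can hand over a result,
// then inspect it. A minimal analog using the C++ standard library instead
// of TensorFlow's mutex/condition_variable wrappers (illustrative names):
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <optional>
#include <thread>

std::mutex mu;
std::condition_variable cond_var;
std::deque<int> results;    // ready results, guarded by mu
bool end_of_input = false;  // guarded by mu

// Consumer: the analog of `while (!Consume(&result)) cond_var_->wait(l);`.
std::optional<int> GetNext() {
  std::unique_lock<std::mutex> l(mu);
  cond_var.wait(l, [] { return !results.empty() || end_of_input; });
  if (results.empty()) return std::nullopt;  // end of sequence
  int value = results.front();
  results.pop_front();
  cond_var.notify_all();  // a buffer slot freed up: wake producers
  return value;
}

int main() {
  std::thread producer([] {
    for (int i = 0; i < 5; ++i) {
      std::lock_guard<std::mutex> l(mu);
      results.push_back(i * i);
      cond_var.notify_all();
    }
    std::lock_guard<std::mutex> l(mu);
    end_of_input = true;
    cond_var.notify_all();
  });
  while (auto value = GetNext()) std::cout << *value << "\n";
  producer.join();
}
// --- end of sketch ---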
@@ -281,37 +271,22 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
      while (num_calls_ > 0) {
        cond_var_->wait(l);
      }
      CHECK_EQ(num_calls_, 0);
      DCHECK_EQ(num_calls_, 0);
      TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name("invocation_results.size"), invocation_results_.size()));
      for (size_t i = 0; i < invocation_results_.size(); i++) {
        std::shared_ptr<InvocationResult> result = invocation_results_[i];
        TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat("invocation_results[", i, "].size")),
            result->return_values.size()));
        for (size_t j = 0; j < result->return_values.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(
                  strings::StrCat("invocation_results[", i, "][", j, "]")),
              result->return_values[j]));
        }
        if (result->skip) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("invocation_results[", i, "].skip")),
              ""));
        }
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name("block_index"), block_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name("cycle_index"), cycle_index_));
      if (end_of_input_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("end_of_input"), ""));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("element_id_counter"),
                                             element_id_counter_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name("num_open"), num_open_));
      TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
      TF_RETURN_IF_ERROR(WriteFutureElements(writer));
      return Status::OK();
    }

@@ -319,200 +294,207 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
                           IteratorStateReader* reader) override {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      int64 invocation_results_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name("invocation_results.size"), &invocation_results_size));
      for (size_t i = 0; i < invocation_results_size; i++) {
        std::shared_ptr<InvocationResult> result(new InvocationResult());
        invocation_results_.push_back(result);
        TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
        size_t num_return_values;
        {
          int64 size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("invocation_results[", i, "].size")),
              &size));
          num_return_values = static_cast<size_t>(size);
          if (num_return_values != size) {
            return errors::InvalidArgument(strings::StrCat(
                full_name(
                    strings::StrCat("invocation_results[", i, "].size")),
                ": ", size, " is not a valid value of type size_t."));
          }
        }
        result->return_values.reserve(num_return_values);
        for (size_t j = 0; j < num_return_values; j++) {
          result->return_values.emplace_back();
          TF_RETURN_IF_ERROR(
              reader->ReadTensor(full_name(strings::StrCat(
                                     "invocation_results[", i, "][", j, "]")),
                                 &result->return_values.back()));
        }
        result->skip = reader->Contains(
            full_name(strings::StrCat("invocation_results[", i, "].skip")));
        result->notification.Notify();
      }
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("block_index"), &block_index_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("element_id_counter"),
                                            &element_id_counter_));
      if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("num_open"), &num_open_));
      TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
      TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
      return Status::OK();
    }

   private:
    // Represents the result of fetching an element from a dataset.
    struct Result {
      Status status;
      std::vector<Tensor> return_values;
      // Indicates whether the result is ready to be consumed.
      bool is_ready = false;
    };

    // The interleave transformation repeatedly inputs elements, applies the
    // user-provided function to transform the input elements to datasets, and
    // interleaves the elements of these datasets as its output.
    //
    // This structure represents an input element and derived state.
    struct Element {
      // Unique identifier, needed to support checkpointing.
      int64 id;
      // The actual input element.
      std::vector<Tensor> inputs;
      // Iterator created from the input element.
      std::unique_ptr<IteratorBase> iterator;
      std::vector<Tensor> inputs;  // inputs for creating the iterator
      bool in_use;
      mutex mu;
      // Buffer for storing the outputs of `iterator`.
      std::deque<std::shared_ptr<Result>> results GUARDED_BY(mu);
      // Indicates whether the element is used by a worker thread.
      bool in_use = false;
    };

    struct InvocationResult {
      Notification notification;  // used for coordination with the consumer
      Status status;  // the invocation status
      std::vector<Tensor> return_values;  // the invocation result values
      bool skip;  // if set the result should be skipped
    };
    // Advances the position in the interleave cycle to the next cycle
    // element.
    void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      block_index_ = 0;
      cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
    }

    void EnsureRunnerThreadStarted(IteratorContext* ctx)
    // Advances the position in the interleave cycle by one.
    void AdvancePosition() EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      ++block_index_;
      if (block_index_ == dataset()->block_length_) {
        AdvanceToNextInCycle();
      }
    }
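// --- standalone sketch, not part of this commit ---
// AdvancePosition()/AdvanceToNextInCycle() implement a two-level cursor:
// block_index_ counts results taken from the current cycle element, and once
// it reaches block_length the cursor wraps to the next element of the cycle.
// The toy below just exercises that arithmetic:
#include <iostream>

struct CyclePosition {
  int cycle_length;
  int block_length;
  int cycle_index = 0;  // which cycle element is being consumed
  int block_index = 0;  // results already taken from that element

  void AdvanceToNextInCycle() {
    block_index = 0;
    cycle_index = (cycle_index + 1) % cycle_length;
  }

  void AdvancePosition() {
    if (++block_index == block_length) {
      AdvanceToNextInCycle();
    }
  }
};

int main() {
  CyclePosition pos{/*cycle_length=*/3, /*block_length=*/2};
  for (int step = 0; step < 8; ++step) {
    std::cout << "element " << pos.cycle_index << ", result "
              << pos.block_index << "\n";
    pos.AdvancePosition();
  }
}
// --- end of sketch ---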

    // Consumes a result (if available), returning an indication of whether
    // a result is available. If `true` is returned, `result` either
    // points to a valid result or is null if end of input has been reached.
    bool Consume(std::shared_ptr<Result>* result)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
        runner_thread_.reset(ctx->env()->StartThread(
            {}, "tf_data_parallel_interleave_runner",
            [this, new_ctx]() { RunnerThread(new_ctx); }));
      if (!sloppy_) {
        return ConsumeHelper(result);
      }
      // If we are allowed to be sloppy (i.e. return results out of order),
      // try to find an element in the cycle that has a result available.
      for (int i = 0; i < dataset()->cycle_length_; ++i) {
        if (ConsumeHelper(result)) {
          return true;
        }
        AdvanceToNextInCycle();
      }
      return false;
    }

    bool ConsumeHelper(std::shared_ptr<Result>* result)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      while (true) {
        std::shared_ptr<Element> element = current_elements_[cycle_index_];
        if (element) {
          mutex_lock l(element->mu);
          if (!element->results.empty()) {
            if (element->results.front()->is_ready) {
              // We found a result.
              std::swap(*result, element->results.front());
              element->results.pop_front();
              AdvancePosition();
              cond_var_->notify_all();
              return true;
            } else {
              // Wait for the result to become ready.
              return false;
            }
          } else if (!element->iterator) {
            // We reached the end of input for this element. Reset
            // it and move on to the next cycle element.
            current_elements_[cycle_index_].reset();
            AdvanceToNextInCycle();
            cond_var_->notify_all();
            continue;
          } else {
            // Wait for the iterator to produce a result.
            return false;
          }
        } else {
          if (!future_elements_.empty() || !end_of_input_) {
            // Wait for an element to be created.
            return false;
          }
          // No new elements will be created; try to find a
          // non-empty element in the cycle.
          for (int i = 0; i < dataset()->cycle_length_; ++i) {
            AdvanceToNextInCycle();
            if (current_elements_[cycle_index_]) {
              break;
            }
          }
          if (current_elements_[cycle_index_]) {
            continue;
          }
          // End of input has been reached.
          return true;
        }
      }
    }
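// --- standalone sketch, not part of this commit ---
// Consume() makes the determinism/sloppiness trade-off explicit: the
// deterministic path only ever takes the head result of the element at
// cycle_index_, while the sloppy path may rotate through the whole cycle
// looking for any element whose next result is ready. A simplified model of
// the two policies over per-element result buffers (it omits iterator
// exhaustion, cursor advancement, and end-of-input; names are illustrative):
#include <deque>
#include <iostream>
#include <optional>
#include <vector>

struct Result {
  int value;
  bool is_ready = false;
};

struct Cycle {
  std::vector<std::deque<Result>> buffers;  // one buffer per cycle element
  size_t cycle_index = 0;

  // Deterministic: succeed only if the head of the current element is ready.
  std::optional<int> ConsumeHelper() {
    auto& buffer = buffers[cycle_index];
    if (!buffer.empty() && buffer.front().is_ready) {
      int value = buffer.front().value;
      buffer.pop_front();
      return value;
    }
    return std::nullopt;  // caller must wait
  }

  // Sloppy: rotate through the cycle until any element has a ready result.
  std::optional<int> ConsumeSloppy() {
    for (size_t i = 0; i < buffers.size(); ++i) {
      if (auto value = ConsumeHelper()) return value;
      cycle_index = (cycle_index + 1) % buffers.size();
    }
    return std::nullopt;
  }
};

int main() {
  Cycle cycle;
  cycle.buffers = {{{1, true}}, {{2, false}}, {{3, true}}};
  if (auto v = cycle.ConsumeSloppy()) std::cout << *v << "\n";  // prints 1
  if (auto v = cycle.ConsumeSloppy()) std::cout << *v << "\n";  // 3: skips 2
}
// --- end of sketch ---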

    // Fetches up to `results.size()` outputs from the cycle element at
    // position `cycle_index`.
    // Manages current cycle elements, creating new iterators as needed and
    // asynchronously fetching results from existing iterators.
    //
    // If end of input is encountered, the `skip` field of the invocation
    // result is used to identify results that should be skipped.
    void FetchOutputs(
        const std::shared_ptr<IteratorContext>& ctx, IteratorBase* iterator,
        int64 cycle_index,
        const std::vector<std::shared_ptr<InvocationResult>>& results)
        LOCKS_EXCLUDED(*mu_) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      bool end_of_input = false;
      for (auto& result : results) {
        if (!end_of_input) {
          result->status = iterator->GetNext(
              ctx.get(), &result->return_values, &end_of_input);
        }
        if (end_of_input) {
          result->skip = true;
        }
        RecordBufferEnqueue(ctx.get(), result->return_values);
        {
          mutex_lock l(*mu_);
          result->notification.Notify();
          cond_var_->notify_all();
        }
        if (!result->status.ok()) {
          break;
        }
      }

      mutex_lock l(*mu_);
      current_elements_[cycle_index].in_use = false;
      if (end_of_input) {
        // Release the ownership of the cycle element iterator, closing the
        // iterator if end of input was encountered.
        current_elements_[cycle_index].iterator.reset();
        current_elements_[cycle_index].inputs.clear();
        num_open_--;
      }
      num_calls_--;
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        stats_aggregator->AddScalar(
            strings::StrCat(key_prefix_, "::thread_utilization"),
            static_cast<float>(num_calls_) /
                static_cast<float>(num_parallel_calls_->value));
      }
      cond_var_->notify_all();
    }

    // Method responsible for 1) creating iterators out of input elements, 2)
    // determining the order in which elements are fetched from the iterators,
    // and 3) scheduling the fetching of the elements to a threadpool.
    //
    // This method runs in the `runner_thread` background thread.
    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
    // This method runs in the `current_elements_manager_` background thread.
    void CurrentElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        return current_elements_[cycle_index_].in_use ||
               num_calls_ >= num_parallel_calls_->value ||
               invocation_results_.size() >=
                   dataset()->cycle_length_ * dataset()->block_length_;
        const bool has_more_elements =
            !future_elements_.empty() || !end_of_input_;
        const int block_length = dataset()->block_length_;
        bool all_elements_busy = true;
        for (auto& element : current_elements_) {
          if (!element) {
            if (has_more_elements) {
              all_elements_busy = false;
              break;
            }
          } else {
            mutex_lock l(element->mu);
            if (!element->in_use && element->iterator &&
                element->results.size() < block_length) {
              all_elements_busy = false;
              break;
            }
          }
        }
        return all_elements_busy || num_calls_ >= num_parallel_calls_->value;
      };
      while (true) {
        mutex_lock l(*mu_);

        // Wait until this thread is cancelled, the end of input has been
        // reached, or the cycle element at the `cycle_index_` position is
        // not in use and there is space in the `invocation_results_` queue.
        // reached.
        while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }

        if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
        if (cancelled_ ||
            (future_elements_.empty() && end_of_input_ && num_open_ == 0)) {
          return;
        }

        while ((!end_of_input_ || num_open_ > 0) && !busy()) {
          if (!current_elements_[cycle_index_].iterator) {
            // Try to create a new iterator from the next input element.
            Status status = input_impl_->GetNext(
                ctx.get(), &current_elements_[cycle_index_].inputs,
                &end_of_input_);
            if (!status.ok()) {
              invocation_results_.emplace_back(new InvocationResult());
              std::shared_ptr<InvocationResult>& result =
                  invocation_results_.back();
              result->status.Update(status);
              result->notification.Notify();
              break;
            }
            if (!end_of_input_) {
              Status status = MakeIteratorFromInputElement(
                  ctx.get(), current_elements_[cycle_index_].inputs,
                  cycle_index_, *instantiated_captured_func_, prefix(),
                  &current_elements_[cycle_index_].iterator);
              if (!status.ok()) {
                invocation_results_.emplace_back(new InvocationResult());
                std::shared_ptr<InvocationResult>& result =
                    invocation_results_.back();
                result->status.Update(status);
                result->notification.Notify();
                break;
          for (int i = 0; i < dataset()->cycle_length_; ++i) {
            int idx = (cycle_index_ + i) % dataset()->cycle_length_;
            if (!current_elements_[idx]) {
              if (!future_elements_.empty()) {
                current_elements_[idx] = std::move(future_elements_.back());
                future_elements_.pop_back();
              } else {
                current_elements_[idx] = MakeElement(ctx);
                if (!current_elements_[idx]) {
                  continue;
                }
                ++num_open_;
              }
            }
            if (current_elements_[cycle_index_].iterator) {
              // Pre-allocate invocation results for outputs to be fetched
              // and then fetch the outputs asynchronously.
              std::vector<std::shared_ptr<InvocationResult>> results;
              results.reserve(dataset()->block_length_);
              for (int i = 0; i < dataset()->block_length_; ++i) {
                invocation_results_.emplace_back(new InvocationResult());
                results.push_back(invocation_results_.back());
            std::shared_ptr<Element> element = current_elements_[idx];
            if (!element->in_use && element->iterator) {
              int64 num_results;
              {
                mutex_lock l(element->mu);
                num_results =
                    dataset()->block_length_ - element->results.size();
              }
              if (num_results > 0) {
                num_calls_++;
                element->in_use = true;
                thread_pool_->Schedule(
                    std::bind(&ParallelInterleaveIterator::FetchResults, this,
                              ctx, std::move(element), num_results));
              }
              num_calls_++;
              current_elements_[cycle_index_].in_use = true;
              thread_pool_->Schedule(
                  std::bind(&ParallelInterleaveIterator::FetchOutputs, this,
                            ctx, current_elements_[cycle_index_].iterator.get(),
                            cycle_index_, std::move(results)));
            }
            cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
          }
          const auto& stats_aggregator = ctx->stats_aggregator();
          if (stats_aggregator) {
@@ -525,56 +507,178 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
          }
        }

    // Determines whether the caller needs to wait for a result. Upon
    // returning false, `result` will either be NULL if end of input has been
    // reached or point to the result.
    bool ShouldWait(std::shared_ptr<InvocationResult>* result)
    void EnsureThreadsStarted(IteratorContext* ctx)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (sloppy_) {
        for (auto it = invocation_results_.begin();
             it != invocation_results_.end(); ++it) {
          if ((*it)->notification.HasBeenNotified()) {
            std::swap(*result, *it);
            invocation_results_.erase(it);
            cond_var_->notify_all();
            return false;
          }
        }
        return !invocation_results_.empty() ||
               (!end_of_input_ || num_open_ > 0);
      } else {
        if (!invocation_results_.empty()) {
          std::swap(*result, invocation_results_.front());
          invocation_results_.pop_front();
          cond_var_->notify_all();
          return false;
        }
        return (!end_of_input_ || num_open_ > 0);
      if (!current_elements_manager_) {
        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
        current_elements_manager_ =
            WrapUnique<Thread>(ctx->env()->StartThread(
                {}, "tf_data_parallel_interleave_current",
                [this, new_ctx]() { CurrentElementsManager(new_ctx); }));
      }
      if (!future_elements_manager_) {
        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
        future_elements_manager_ = WrapUnique<Thread>(ctx->env()->StartThread(
            {}, "tf_data_parallel_interleave_future",
            [this, new_ctx]() { FutureElementsManager(new_ctx); }));
      }
    }

    Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
    // Fetches up to `dataset()->block_length_` results from `element`.
    void FetchResults(const std::shared_ptr<IteratorContext>& ctx,
                      const std::shared_ptr<Element>& element,
                      int64 num_results) LOCKS_EXCLUDED(*mu_) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      bool end_of_input = false;
      for (int64 i = 0; i < num_results; ++i) {
        auto result = std::make_shared<Result>();
        result->status = element->iterator->GetNext(
            ctx.get(), &result->return_values, &end_of_input);
        if (end_of_input) {
          break;
        }
        RecordBufferEnqueue(ctx.get(), result->return_values);
        mutex_lock l(*mu_);
        mutex_lock l2(element->mu);
        element->results.push_back(result);
        result->is_ready = true;
        cond_var_->notify_all();
      }

      mutex_lock l(*mu_);
      // Release the ownership of the cycle element iterator.
      element->in_use = false;
      if (end_of_input) {
        // Close the iterator if end of input was encountered.
        element->iterator.reset();
        element->inputs.clear();
        --num_open_;
      }
      --num_calls_;
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        stats_aggregator->AddScalar(
            strings::StrCat(key_prefix_, "::thread_utilization"),
            static_cast<float>(num_calls_) /
                static_cast<float>(num_parallel_calls_->value));
      }
      cond_var_->notify_all();
    }
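// --- standalone sketch, not part of this commit ---
// FetchResults() is the producer half of the scheme: a worker claims one
// element (in_use = true, set by the manager), pulls up to block_length
// results from that element's iterator into the element's buffer, waking
// waiters as each result lands, and finally hands the element back. A
// condensed analog with std::thread standing in for the threadpool
// (illustrative names; the fake "iterator" just counts up to 10):
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

std::mutex mu;                    // shared state lock, analog of *mu_
std::condition_variable cond_var;

struct Element {
  std::mutex mu;                  // guards results, analog of element->mu
  std::deque<int> results;
  std::unique_ptr<int> iterator;  // stand-in for the element's sub-iterator
  bool in_use = false;
};

void FetchResults(const std::shared_ptr<Element>& element, int num_results) {
  bool end_of_input = false;
  for (int i = 0; i < num_results; ++i) {
    int next = (*element->iterator)++;  // "GetNext" on the fake iterator
    if (next >= 10) {
      end_of_input = true;
      break;
    }
    std::scoped_lock l(mu, element->mu);
    element->results.push_back(next);
    cond_var.notify_all();  // a result became ready for the consumer
  }
  std::lock_guard<std::mutex> l(mu);
  element->in_use = false;  // hand the element back to the manager
  if (end_of_input) element->iterator.reset();  // close exhausted iterator
  cond_var.notify_all();
}

int main() {
  auto element = std::make_shared<Element>();
  element->iterator = std::make_unique<int>(0);
  element->in_use = true;
  std::thread worker(FetchResults, element, /*num_results=*/3);
  worker.join();
  std::cout << element->results.size() << " results buffered\n";  // 3
}
// --- end of sketch ---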

    // Manages future cycle elements, creating new iterators as needed and
    // asynchronously fetching results from existing iterators.
    //
    // This method runs in the `future_elements_manager_` background thread.
    void FutureElementsManager(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        return num_calls_ >= num_parallel_calls_->value ||
               future_elements_.size() >= dataset()->cycle_length_;
      };
      while (true) {
        mutex_lock l(*mu_);

        // Wait until this thread is cancelled, the end of input has been
        // reached, or there is room to prefetch and fill another future
        // element.
        while (!cancelled_ && !end_of_input_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }

        if (cancelled_ || end_of_input_) {
          return;
        }

        while (!end_of_input_ && !busy()) {
          std::shared_ptr<Element> element = MakeElement(ctx);
          if (!element) {
            break;
          }
          future_elements_.push_front(element);
          if (!element->iterator) {
            continue;
          }
          ++num_calls_;
          element->in_use = true;
          thread_pool_->Schedule(
              std::bind(&ParallelInterleaveIterator::FetchResults, this, ctx,
                        std::move(element), dataset()->block_length_));
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              strings::StrCat(key_prefix_, "::thread_utilization"),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value));
        }
        cond_var_->notify_all();
      }
    }
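// --- standalone sketch, not part of this commit ---
// FutureElementsManager() is the prefetching named in the commit title: a
// dedicated thread keeps a deque of up to cycle_length "future" elements
// (input element plus freshly created iterator, already being filled) so
// that when a cycle slot frees up, CurrentElementsManager() can move a
// prefetched element in instead of paying iterator-creation latency on the
// critical path. A schematic top-up loop under a busy() predicate
// (simplified; illustrative names, with a fake 100-element input):
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

struct Element { int id; };

std::mutex mu;
std::condition_variable cond_var;
std::deque<std::shared_ptr<Element>> future_elements;  // guarded by mu
bool cancelled = false;                                // guarded by mu
bool end_of_input = false;                             // guarded by mu
int next_id = 0;
const size_t kCycleLength = 4;

void FutureElementsManager() {
  auto busy = [] { return future_elements.size() >= kCycleLength; };
  while (true) {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [&] { return cancelled || end_of_input || !busy(); });
    if (cancelled || end_of_input) return;
    while (!end_of_input && !busy()) {
      // The real op pulls the next input element here, builds an iterator
      // from it, and schedules FetchResults() to start filling its buffer.
      future_elements.push_front(std::make_shared<Element>(Element{next_id}));
      end_of_input = ++next_id >= 100;
    }
    cond_var.notify_all();
  }
}

int main() {
  std::thread manager(FutureElementsManager);
  for (int taken = 0; taken < 10; ++taken) {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [] { return !future_elements.empty(); });
    std::cout << "cycle adopts prefetched element "
              << future_elements.back()->id << "\n";
    future_elements.pop_back();
    cond_var.notify_all();  // room freed: the manager prefetches more
  }
  {
    std::lock_guard<std::mutex> l(mu);
    cancelled = true;
  }
  cond_var.notify_all();
  manager.join();
}
// --- end of sketch ---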

    // Creates a new element.
    std::shared_ptr<Element> MakeElement(
        const std::shared_ptr<IteratorContext>& ctx)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      auto element = std::make_shared<Element>();
      element->id = element_id_counter_++;
      Status status =
          input_impl_->GetNext(ctx.get(), &element->inputs, &end_of_input_);
      if (!status.ok()) {
        auto result = std::make_shared<Result>();
        result->is_ready = true;
        result->status = status;
        mutex_lock l(element->mu);
        element->results.push_back(std::move(result));
        return element;
      }
      if (!end_of_input_) {
        Status status = MakeIteratorFromInputElement(
            ctx.get(), element->inputs, element->id,
            *instantiated_captured_func_, prefix(), &element->iterator);
        if (!status.ok()) {
          auto result = std::make_shared<Result>();
          result->is_ready = true;
          result->status = status;
          mutex_lock l(element->mu);
          element->results.push_back(std::move(result));
          return element;
        }
        ++num_open_;
      } else {
        element.reset();
      }
      return element;
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const string& key_prefix, size_t idx,
                             const Status& status)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          CodeKey(index), static_cast<int64>(status.code())));
          CodeKey(key_prefix, idx), static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
                                               status.error_message()));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            ErrorMessageKey(key_prefix, idx), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
    Status ReadStatusLocked(IteratorStateReader* reader,
                            const string& key_prefix, size_t idx,
                            Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64 code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(CodeKey(key_prefix, idx), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        string error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(ErrorMessageKey(index), &error_message));
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            ErrorMessageKey(key_prefix, idx), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
@@ -582,64 +686,178 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
      return Status::OK();
    }

    string CodeKey(size_t index) {
    string CodeKey(const string& key_prefix, size_t idx) {
      return full_name(
          strings::StrCat("invocation_results[", index, "].code"));
          strings::StrCat(key_prefix, ".results[", idx, "].code"));
    }

    string ErrorMessageKey(size_t index) {
    string ErrorMessageKey(const string& key_prefix, size_t idx) {
      return full_name(
          strings::StrCat("invocation_results[", index, "].error_message"));
          strings::StrCat(key_prefix, ".results[", idx, "].error_message"));
    }
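// --- standalone sketch, not part of this commit ---
// Checkpoint state goes into a flat key-value table, so the hierarchy
// (element index, result index, field) lives entirely in the key strings
// that CodeKey()/ErrorMessageKey() build. Plain std::string stands in for
// strings::StrCat/full_name below, and the "Iterator::..." prefix is an
// assumed placeholder, not the op's actual prefix:
#include <iostream>
#include <string>

std::string FullName(const std::string& key) {
  return "Iterator::ParallelInterleave::" + key;  // assumed prefix
}

std::string CodeKey(const std::string& key_prefix, size_t idx) {
  return FullName(key_prefix + ".results[" + std::to_string(idx) + "].code");
}

std::string ErrorMessageKey(const std::string& key_prefix, size_t idx) {
  return FullName(key_prefix + ".results[" + std::to_string(idx) +
                  "].error_message");
}

int main() {
  // Status code of the second buffered result of the fourth current element:
  std::cout << CodeKey("current_elements[3]", 1) << "\n";
  // -> Iterator::ParallelInterleave::current_elements[3].results[1].code
}
// --- end of sketch ---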

    Status WriteElement(std::shared_ptr<Element> element, int idx,
                        const string& key_prefix, IteratorStateWriter* writer)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (element->iterator) {
        TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(key_prefix, "[", idx, "].id")),
            element->id));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")),
            element->inputs.size()));
        for (int i = 0; i < element->inputs.size(); i++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(
                  strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")),
              element->inputs[i]));
        }
      }
      mutex_lock l(element->mu);
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")),
          element->results.size()));
      for (size_t i = 0; i < element->results.size(); i++) {
        std::shared_ptr<Result> result = element->results[i];
        TF_RETURN_IF_ERROR(WriteStatusLocked(
            writer, strings::StrCat(key_prefix, "[", idx, "]"), i,
            result->status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
                                      "].size")),
            result->return_values.size()));
        for (size_t j = 0; j < result->return_values.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
                                        "][", j, "]")),
              result->return_values[j]));
        }
        if (result->is_ready) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
                                        "].is_ready")),
              ""));
        }
      }
      return Status::OK();
    }

    Status WriteCurrentElements(IteratorStateWriter* writer)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name("current_elements.size"), current_elements_.size()));
      for (int idx = 0; idx < current_elements_.size(); idx++) {
        if (current_elements_[idx].iterator) {
          TF_RETURN_IF_ERROR(
              SaveInput(writer, current_elements_[idx].iterator));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(
                  strings::StrCat("current_elements[", idx, "].inputs.size")),
              current_elements_[idx].inputs.size()));
          for (int i = 0; i < current_elements_[idx].inputs.size(); i++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("current_elements[", idx,
                                          "].inputs[", i, "]")),
                current_elements_[idx].inputs[i]));
          }
        if (current_elements_[idx]) {
          TF_RETURN_IF_ERROR(WriteElement(current_elements_[idx], idx,
                                          "current_elements", writer));
        }
      }
      return Status::OK();
    }

    Status WriteFutureElements(IteratorStateWriter* writer)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name("future_elements.size"), future_elements_.size()));
      for (int idx = 0; idx < future_elements_.size(); idx++) {
        if (future_elements_[idx]) {
          TF_RETURN_IF_ERROR(WriteElement(future_elements_[idx], idx,
                                          "future_elements", writer));
        }
      }
      return Status::OK();
    }

    Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
                       int idx, const string& key_prefix,
                       std::shared_ptr<Element>* out)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!reader->Contains(full_name(
              strings::StrCat(key_prefix, "[", idx, "].results.size")))) {
        return Status::OK();
      }
      auto element = std::make_shared<Element>();
      mutex_lock l(element->mu);
      int64 results_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(key_prefix, "[", idx, "].results.size")),
          &results_size));
      element->results.resize(results_size);
      for (size_t i = 0; i < results_size; i++) {
        auto result = std::make_shared<Result>();
        TF_RETURN_IF_ERROR(ReadStatusLocked(
            reader, strings::StrCat(key_prefix, "[", idx, "]"), i,
            &result->status));
        int64 num_return_values;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
                                      "].size")),
            &num_return_values));
        result->return_values.reserve(num_return_values);
        for (size_t j = 0; j < num_return_values; j++) {
          result->return_values.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(key_prefix, "[", idx, "].results[", i,
                                        "][", j, "]")),
              &result->return_values.back()));
        }
        result->is_ready = reader->Contains(full_name(strings::StrCat(
            key_prefix, "[", idx, "].results[", i, "].is_ready")));
        element->results[i] = std::move(result);
      }
      if (!reader->Contains(full_name(
              strings::StrCat(key_prefix, "[", idx, "].inputs.size")))) {
        element->iterator.reset();
        *out = std::move(element);
        return Status::OK();
      }
      int64 inputs_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(key_prefix, "[", idx, "].inputs.size")),
          &inputs_size));
      element->inputs.resize(inputs_size);
      for (int i = 0; i < inputs_size; i++) {
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            full_name(
                strings::StrCat(key_prefix, "[", idx, "].inputs[", i, "]")),
            &element->inputs[i]));
      }
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(key_prefix, "[", idx, "].id")),
          &element->id));
      TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
          ctx, element->inputs, element->id,
          *instantiated_captured_func_.get(), prefix(), &element->iterator));
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, element->iterator));
      *out = std::move(element);
      return Status::OK();
    }

    Status ReadCurrentElements(IteratorContext* ctx,
                               IteratorStateReader* reader)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64 size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("current_elements.size"), &size));
      DCHECK_EQ(current_elements_.size(), size);
      for (int idx = 0; idx < current_elements_.size(); idx++) {
        if (reader->Contains(full_name(strings::StrCat(
                "current_elements[", idx, "].inputs.size")))) {
          int64 inputs_size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(
                  strings::StrCat("current_elements[", idx, "].inputs.size")),
              &inputs_size));
          current_elements_[idx].inputs.resize(inputs_size);
          for (int i = 0; i < inputs_size; i++) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                full_name(strings::StrCat("current_elements[", idx,
                                          "].inputs[", i, "]")),
                &current_elements_[idx].inputs[i]));
          }
          TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
              ctx, current_elements_[idx].inputs, idx,
              *instantiated_captured_func_.get(), prefix(),
              &current_elements_[idx].iterator));
          TF_RETURN_IF_ERROR(
              RestoreInput(ctx, reader, current_elements_[idx].iterator));
        } else {
          current_elements_[idx].iterator.reset();
        }
        TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "current_elements",
                                       &current_elements_[idx]));
      }
      return Status::OK();
    }

    Status ReadFutureElements(IteratorContext* ctx,
                              IteratorStateReader* reader)
        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64 size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("future_elements.size"), &size));
      future_elements_.resize(size);
      for (int idx = 0; idx < future_elements_.size(); idx++) {
        TF_RETURN_IF_ERROR(ReadElement(ctx, reader, idx, "future_elements",
                                       &future_elements_[idx]));
      }
      return Status::OK();
    }
@@ -648,12 +866,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
    // the worker threads.
    const std::shared_ptr<mutex> mu_;

    // Used for coordination between the main thread, the runner thread, and
    // the worker threads. In particular, the runner thread should only
    // schedule new calls when the number of in-flight calls is less than the
    // user specified level of parallelism, there are slots available in the
    // `invocation_results_` buffer, the current cycle element is not in use,
    // and there are elements left to be fetched.
    // Used for coordination between the main thread, the manager threads, and
    // the threadpool threads. In particular, the manager threads should only
    // schedule new calls into the threadpool when the number of in-flight
    // calls is less than the user specified level of parallelism and there
    // are slots available in the element `results` buffer.
    const std::shared_ptr<condition_variable> cond_var_;

    // Identifies the maximum number of parallel calls.
@@ -665,18 +882,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
    // Iterator for input elements.
    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);

    // Identifies current cycle element.
    int64 cycle_index_ = 0;
    // Identifies position in the interleave cycle.
    int64 block_index_ GUARDED_BY(*mu_) = 0;
    int64 cycle_index_ GUARDED_BY(*mu_) = 0;

    // Iterators for the current cycle elements. Concurrent access is
    // protected by `element_in_use_`.
    std::vector<Element> current_elements_ GUARDED_BY(*mu_);
    // Elements of the current interleave cycle.
    std::vector<std::shared_ptr<Element>> current_elements_ GUARDED_BY(*mu_);

    // Buffer for storing the invocation results.
    std::deque<std::shared_ptr<InvocationResult>> invocation_results_
        GUARDED_BY(*mu_);
    // Elements to be used in the interleave cycle in the future.
    std::deque<std::shared_ptr<Element>> future_elements_ GUARDED_BY(*mu_);

    // Identifies whether end of input has been reached.
    // Identifies whether the global end of input has been reached.
    bool end_of_input_ GUARDED_BY(*mu_) = false;

    // Identifies the number of open iterators.
@@ -686,9 +902,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
    int64 num_calls_ GUARDED_BY(*mu_) = 0;

    std::unique_ptr<thread::ThreadPool> thread_pool_;
    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
    std::unique_ptr<Thread> current_elements_manager_ GUARDED_BY(*mu_);
    std::unique_ptr<Thread> future_elements_manager_ GUARDED_BY(*mu_);
    int64 element_id_counter_ GUARDED_BY(*mu_) = 0;

    // Identifies whether background activity should be cancelled.
    // Identifies whether background threads should be cancelled.
    bool cancelled_ GUARDED_BY(*mu_) = false;
    string key_prefix_;
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
@@ -17,19 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test

@@ -78,47 +74,6 @@ def _interleave(lists, cycle_length, block_length):
      break


def _make_coordinated_sloppy_dataset(input_values, cycle_length, block_length,
                                     num_parallel_calls):
  """Produces a dataset iterator and events to control the order of elements.

  Args:
    input_values: the values to generate lists to interleave from
    cycle_length: the length of the interleave cycle
    block_length: the length of the interleave block
    num_parallel_calls: the degree of interleave parallelism

  Returns:
    A dataset iterator (represented as `get_next` op) and events that can be
    used to control the order of output elements.
  """

  # Set up threading events used to sequence when items are produced that
  # are subsequently interleaved. These events allow us to deterministically
  # simulate slowdowns and force sloppiness.
  coordination_events = {i: threading.Event() for i in input_values}

  def map_py_fn(x):
    coordination_events[x].wait()
    coordination_events[x].clear()
    return x * x

  def map_fn(x):
    return script_ops.py_func(map_py_fn, [x], x.dtype)

  def interleave_fn(x):
    dataset = dataset_ops.Dataset.from_tensors(x)
    dataset = dataset.repeat(x)
    return dataset.map(map_fn)

  options = dataset_ops.Options()
  options.experimental_deterministic = False
  dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
      2).interleave(interleave_fn, cycle_length, block_length,
                    num_parallel_calls).with_options(options)
  return dataset, coordination_events


def _repeat(values, count):
  """Produces a list of lists suitable for testing interleave.

@@ -252,63 +207,37 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
      self.evaluate(get_next())

  @parameterized.named_parameters(
      ("1", np.int64([4, 5, 6]), 2, 1, 1),
      ("2", np.int64([4, 5, 6]), 2, 1, 2),
      ("3", np.int64([4, 5, 6]), 2, 3, 1),
      ("4", np.int64([4, 5, 6]), 2, 3, 2),
      ("5", np.int64([4, 5, 6]), 3, 2, 1),
      ("6", np.int64([4, 5, 6]), 3, 2, 2),
      ("7", np.int64([4, 5, 6]), 3, 2, 3),
      ("8", np.int64([4, 0, 6]), 2, 3, 1),
      ("9", np.int64([4, 0, 6]), 2, 3, 2),
      ("1", np.int64([4, 5, 6]), 1, 3, 1),
      ("2", np.int64([4, 5, 6]), 2, 1, 1),
      ("3", np.int64([4, 5, 6]), 2, 1, 2),
      ("4", np.int64([4, 5, 6]), 2, 3, 1),
      ("5", np.int64([4, 5, 6]), 2, 3, 2),
      ("6", np.int64([4, 5, 6]), 7, 2, 1),
      ("7", np.int64([4, 5, 6]), 7, 2, 3),
      ("8", np.int64([4, 5, 6]), 7, 2, 5),
      ("9", np.int64([4, 5, 6]), 7, 2, 7),
      ("10", np.int64([4, 0, 6]), 2, 3, 1),
      ("11", np.int64([4, 0, 6]), 2, 3, 2),
  )
  def testSloppyInterleaveInOrder(self, input_values, cycle_length,
  def testSloppyInterleaveDataset(self, input_values, cycle_length,
                                  block_length, num_parallel_calls):
    dataset, coordination_events = _make_coordinated_sloppy_dataset(
        input_values, cycle_length, block_length, num_parallel_calls)
    count = 2
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        count).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length, block_length, num_parallel_calls)
    options = dataset_ops.Options()
    options.experimental_threading = threading_options.ThreadingOptions()
    options.experimental_threading.private_threadpool_size = (
        num_parallel_calls + 1)
    options.experimental_deterministic = False
    dataset = dataset.with_options(options)

    get_next = self.getNext(dataset, requires_initialization=True)
    for expected_element in _interleave(
        _repeat(input_values, 2), cycle_length, block_length):
      coordination_events[expected_element].set()
      self.assertEqual(expected_element * expected_element,
                       self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @parameterized.named_parameters(
      ("1", np.int64([4, 5, 6]), 2, 1, 2),
      ("2", np.int64([4, 5, 6]), 2, 3, 2),
      ("3", np.int64([4, 5, 6]), 3, 2, 3),
      ("4", np.int64([4, 0, 6]), 2, 3, 2),
  )
  def testSloppyInterleaveOutOfOrder(self, input_values, cycle_length,
                                     block_length, num_parallel_calls):
    dataset, coordination_events = _make_coordinated_sloppy_dataset(
        input_values, cycle_length, block_length, num_parallel_calls)
    options = dataset_ops.Options()
    options.experimental_threading = threading_options.ThreadingOptions()
    options.experimental_threading.private_threadpool_size = (
        num_parallel_calls + 1)
    dataset = dataset.with_options(options)
    get_next = self.getNext(dataset, requires_initialization=True)
    elements = [
        x for x in _interleave(
            _repeat(input_values, 2), cycle_length, block_length)
    expected_output = [
        element for element in _interleave(
            _repeat(input_values, count), cycle_length, block_length)
    ]
    for i in [1, 4, 7]:
      elements[i], elements[i + 1] = elements[i + 1], elements[i]

    for element in elements:
      coordination_events[element].set()
      self.assertEqual(element * element, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    get_next = self.getNext(dataset)
    actual_output = []
    for _ in range(len(expected_output)):
      actual_output.append(self.evaluate(get_next()))
    self.assertAllEqual(sorted(expected_output), sorted(actual_output))


if __name__ == "__main__":