Get rid of the IteratorBase::is_exhausted_ flag: it cannot be relied on unless every call to GetNext takes a lock, which is undesirable.

Each iterator now handles saving/restoring its own exhausted state.
As a guideline, we always reset the input_impl(s) once they are exhausted; a null input_impl can then serve as the indicator of exhaustion for non-terminal iterators. Releasing the input also reduces memory overhead.
Each iterator must also handle calls to GetNextInternal after it is exhausted, by reporting end_of_sequence instead of failing. This is fixed here for some of the datasets.
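
For reference, a minimal sketch of the pattern. PassThroughIterator is a hypothetical single-input iterator; mu_, input_impl_, SaveParent/RestoreParent, full_name, and the "input_impl_empty" key are the ones used in the diffs below.

  class PassThroughIterator : public DatasetIterator<Dataset> {
   public:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      // A null input_impl_ is the exhaustion indicator; GetNext may keep
      // being called after the end of the sequence, so report
      // end_of_sequence instead of failing.
      if (!input_impl_) {
        *end_of_sequence = true;
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
      if (*end_of_sequence) {
        input_impl_.reset();  // Reset eagerly to free the input's state.
      }
      return Status::OK();
    }

   protected:
    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        // Record exhaustion explicitly instead of saving the input.
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("input_impl_empty"), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
      }
      return Status::OK();
    }

    Status RestoreInternal(OpKernelContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name("input_impl_empty"))) {
        input_impl_.reset();  // The checkpoint was taken after exhaustion.
      } else {
        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
  };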
Also fix a bug in dataset_serialization_test_base: we were not saving a
checkpoint after exhausting the iterator, so verify_exhausted_iterator
was not actually testing the restoration of an exhausted iterator.

PiperOrigin-RevId: 175253023
Saurabh Saxena 2017-11-09 21:01:00 -08:00 committed by TensorFlower Gardener
parent badd356488
commit 3c41cb6bff
12 changed files with 120 additions and 57 deletions


@@ -337,11 +337,11 @@ class DatasetSerializationTestBase(test.TestCase):
         num_iters = end - start
         for _ in range(num_iters):
           outputs.append(sess.run(get_next_op))
-        self._save(sess, saver)
-        ckpt_saved = True
         if i == len(break_points) and verify_exhausted:
           with self.assertRaises(errors.OutOfRangeError):
             sess.run(get_next_op)
+        self._save(sess, saver)
+        ckpt_saved = True
     return outputs


@@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
       // Each row of `batch_elements` is a tuple of tensors from the
       // input iterator.
       std::vector<std::vector<Tensor>> batch_elements;
-      batch_elements.reserve(dataset()->batch_size_);
       {
         mutex_lock l(mu_);
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+        batch_elements.reserve(dataset()->batch_size_);
         *end_of_sequence = false;
         for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
              ++i) {
@@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
               end_of_sequence));
           if (!*end_of_sequence) {
             batch_elements.emplace_back(std::move(batch_element_tuple));
+          } else {
+            input_impl_.reset();
           }
         }
       }
@@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      if (!input_impl_) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("input_impl_empty"), ""));
+      } else {
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      }
       return Status::OK();
     }
 
     Status RestoreInternal(OpKernelContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      if (!reader->Contains(full_name("input_impl_empty"))) {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      } else {
+        input_impl_.reset();
+      }
       return Status::OK();
     }


@@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
+      if (!input_impl_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       while (i_ < 2) {
         TF_RETURN_IF_ERROR(
             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
       } else if (i_ == 2) {
         input_impl_.reset();
       }
-      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      if (input_impl_) {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      }
       return Status::OK();
     }


@@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
   MakeDataset(ctx, input, another_input, output);
 }
 
-const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
 const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
 const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
     "_DATASET_GRAPH_OUTPUT_NODE";


@@ -306,27 +306,14 @@ class IteratorBase {
   // Saves the state of this iterator.
   virtual Status Save(IteratorStateWriter* writer) {
-    if (is_exhausted_) {
-      LOG(INFO) << "Iterator exhausted.";
-      return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
-    } else {
-      return SaveInternal(writer);
-    }
+    return SaveInternal(writer);
   }
 
   // Restores the state of this iterator.
   virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
-    if (reader->Contains(kIteratorExhausted)) {
-      LOG(INFO) << "Iterator exhausted. Nothing to restore.";
-      is_exhausted_ = true;
-      return Status::OK();
-    } else {
-      return RestoreInternal(ctx, reader);
-    }
+    return RestoreInternal(ctx, reader);
   }
 
-  static const char kIteratorExhausted[];
-
  protected:
   // This is needed so that sub-classes of IteratorBase can call
   // `SaveInternal` on their parent iterators, e.g., in
@@ -354,8 +341,6 @@ class IteratorBase {
                                  IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }
-
-  bool is_exhausted_ = false;  // Whether the iterator has been exhausted.
 };
 
 // Represents a (potentially infinite) range of outputs, where each
@@ -491,10 +476,6 @@ class DatasetIterator : public IteratorBase {
   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) final {
     port::Tracing::TraceMe activity(params_.prefix);
-    if (is_exhausted_) {
-      *end_of_sequence = true;
-      return Status::OK();
-    }
     return GetNextInternal(ctx, out_tensors, end_of_sequence);
   }


@@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel {
       if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
           (dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
         *end_of_sequence = true;
-        is_exhausted_ = true;
         return Status::OK();
       }
       Tensor value_tensor(cpu_allocator(), DT_INT64, {});


@@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
       // Iteration ends when there are no more files to process.
       if (current_file_index_ == dataset()->filenames_.size()) {
         *end_of_sequence = true;
-        is_exhausted_ = true;
         return Status::OK();
       }


@@ -117,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
+      if (!input_impl_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       while (i_ < dataset()->count_) {
         TF_RETURN_IF_ERROR(
             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -127,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
         input_impl_ = dataset()->input_->MakeIterator(prefix());
       }
       *end_of_sequence = true;
-      is_exhausted_ = true;
       input_impl_.reset();
       return Status::OK();
     }
@@ -136,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      if (!input_impl_) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("input_impl_empty"), ""));
+      } else {
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      }
       return Status::OK();
     }
@@ -144,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      if (!reader->Contains(full_name("input_impl_empty"))) {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      } else {
+        input_impl_.reset();
+      }
       return Status::OK();
     }


@@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
       mutex_lock l(mu_);
       int64 start_micros = ctx->env()->NowMicros();
       int64 num_log_entries = 0;
-      while (!end_of_input_sequence_ &&
-             buffer_.size() < dataset()->buffer_size_) {
+      while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
         if (ctx->env()->NowMicros() >
             ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
           num_log_entries++;
@@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
                     << buffer_.size() << " of " << dataset()->buffer_size_;
         }
         std::vector<Tensor> input_element;
+        bool end_of_input_sequence;
         TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
-                                                &end_of_input_sequence_));
-        if (!end_of_input_sequence_) {
+                                                &end_of_input_sequence));
+        if (!end_of_input_sequence) {
           buffer_.emplace_back(std::move(input_element));
         } else {
           input_impl_.reset();
@@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
         std::swap(buffer_[index], buffer_.back());
         buffer_.pop_back();
       } else {
-        DCHECK(end_of_input_sequence_);
+        DCHECK(input_impl_ == nullptr);
         *end_of_sequence = true;
       }
       return Status::OK();
@@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
       // Save the tensors in the buffer.
       TF_RETURN_IF_ERROR(
           writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
-      for (int i = 0; i < buffer_.size(); i++) {
+      for (size_t i = 0; i < buffer_.size(); i++) {
         TF_RETURN_IF_ERROR(writer->WriteScalar(
             full_name(strings::StrCat("buffer_", i, "_size")),
             buffer_[i].size()));
-        for (int j = 0; j < buffer_[i].size(); j++) {
+        for (size_t j = 0; j < buffer_[i].size(); j++) {
           TF_RETURN_IF_ERROR(writer->WriteTensor(
               full_name(strings::StrCat("buffer_", i, "_", j)),
               buffer_[i][j]));
@@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
       // Save input iterator if it hasn't been exhausted else write
       // "end_of_input_sequence".
-      if (end_of_input_sequence_) {
+      if (!input_impl_) {
         TF_RETURN_IF_ERROR(
             writer->WriteScalar(full_name("end_of_input_sequence"), ""));
       } else {
@@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
       buffer_.clear();
       // Restore the buffer.
-      int64 buffer_size;
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar(full_name("buffer_size"), &buffer_size));
-      for (int i = 0; i < buffer_size; i++) {
+      size_t buffer_size;
+      {
+        int64 temp;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("buffer_size"), &temp));
+        buffer_size = static_cast<size_t>(temp);
+      }
+      buffer_.reserve(buffer_size);
+      for (size_t i = 0; i < buffer_size; i++) {
         int64 list_size;
         TF_RETURN_IF_ERROR(reader->ReadScalar(
             full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
@@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
         input_impl_ = dataset()->input_->MakeIterator(prefix());
         TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
       } else {
-        end_of_input_sequence_ = true;
         input_impl_.reset();
       }
       return Status::OK();
@@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
     mutex mu_;
     std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
-    bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
     const int64 seed_ GUARDED_BY(mu_);
     const int64 seed2_ GUARDED_BY(mu_);
     random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);


@@ -118,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
+      if (!input_impl_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
+
       // Keep calling GetNext().  TODO(vrv): Figure out a way to
       // skip records without reading, perhaps by adding an
       // interface to iterator.
@@ -138,6 +143,9 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
       // Return GetNext() on the underlying iterator.
       TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
                                               end_of_sequence));
+      if (*end_of_sequence) {
+        input_impl_.reset();
+      }
       return Status::OK();
     }
@@ -145,7 +153,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      if (input_impl_) {
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      } else {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("input_impl_empty"), ""));
+      }
       return Status::OK();
     }
@@ -153,7 +166,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      if (!reader->Contains(full_name("input_impl_empty"))) {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      } else {
+        input_impl_.reset();
+      }
       return Status::OK();
     }


@@ -118,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
+      if (!input_impl_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       while (i_ < dataset()->count_) {
         TF_RETURN_IF_ERROR(
             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -136,7 +140,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      if (input_impl_) {
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      } else {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("input_impl_empty"), ""));
+      }
       return Status::OK();
     }
@@ -144,7 +153,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      if (!reader->Contains(full_name("input_impl_empty"))) {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      } else {
+        input_impl_.reset();
+      }
       return Status::OK();
     }


@@ -109,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
+      if (input_impls_.empty()) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       out_tensors->clear();
       out_tensors->reserve(dataset()->output_dtypes().size());
       for (const auto& input_impl : input_impls_) {
@@ -116,28 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel {
         TF_RETURN_IF_ERROR(
             input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
         if (*end_of_sequence) {
-          return Status::OK();
+          break;
         }
         out_tensors->insert(out_tensors->end(), input_tensors.begin(),
                             input_tensors.end());
       }
-      *end_of_sequence = false;
+      if (*end_of_sequence) {
+        out_tensors->clear();
+        input_impls_.clear();
+      } else {
+        *end_of_sequence = false;
+      }
       return Status::OK();
     }
 
    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      for (auto& input_impl : input_impls_)
-        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+      if (input_impls_.empty()) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("input_impls_empty"), ""));
+      } else {
+        for (auto& input_impl : input_impls_)
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+      }
       return Status::OK();
     }
 
     Status RestoreInternal(OpKernelContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      for (auto& input_impl : input_impls_)
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+      if (reader->Contains(full_name("input_impls_empty"))) {
+        input_impls_.clear();
+      } else {
+        DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
+        for (auto& input_impl : input_impls_)
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+      }
      return Status::OK();
     }