Get rid of the IteratorBase::is_exhausted flag, since it is not possible to rely on it unless we lock each call to GetNext, which is not preferable.
Each iterator now handles saving/restoring its own exhausted state. As a guideline, we always reset the input_impl(s) when they get exhausted; for non-terminal iterators, a null input_impl can then be used as the indicator of exhaustion. This also reduces memory overhead. Each iterator should also handle calls to GetNextInternal after it is exhausted. Fixed this for some datasets. Also fixed a bug in dataset_serialization_test_base: we were not saving a checkpoint after exhausting the iterator, so verify_exhausted_iterator was not really testing the restoration of an exhausted iterator. PiperOrigin-RevId: 175253023
This commit is contained in:
parent
badd356488
commit
3c41cb6bff
@ -337,11 +337,11 @@ class DatasetSerializationTestBase(test.TestCase):
|
||||
num_iters = end - start
|
||||
for _ in range(num_iters):
|
||||
outputs.append(sess.run(get_next_op))
|
||||
self._save(sess, saver)
|
||||
ckpt_saved = True
|
||||
if i == len(break_points) and verify_exhausted:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self._save(sess, saver)
|
||||
ckpt_saved = True
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Each row of `batch_elements` is a tuple of tensors from the
|
||||
// input iterator.
|
||||
std::vector<std::vector<Tensor>> batch_elements;
|
||||
batch_elements.reserve(dataset()->batch_size_);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
batch_elements.reserve(dataset()->batch_size_);
|
||||
*end_of_sequence = false;
|
||||
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
|
||||
++i) {
|
||||
@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
end_of_sequence));
|
||||
if (!*end_of_sequence) {
|
||||
batch_elements.emplace_back(std::move(batch_element_tuple));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
while (i_ < 2) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||
@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
|
||||
} else if (i_ == 2) {
|
||||
input_impl_.reset();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
|
||||
MakeDataset(ctx, input, another_input, output);
|
||||
}
|
||||
|
||||
const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
|
||||
const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
|
||||
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
|
||||
"_DATASET_GRAPH_OUTPUT_NODE";
|
||||
|
@ -306,27 +306,14 @@ class IteratorBase {
|
||||
|
||||
// Saves the state of this iterator.
|
||||
virtual Status Save(IteratorStateWriter* writer) {
|
||||
if (is_exhausted_) {
|
||||
LOG(INFO) << "Iterator exhausted.";
|
||||
return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
|
||||
} else {
|
||||
return SaveInternal(writer);
|
||||
}
|
||||
return SaveInternal(writer);
|
||||
}
|
||||
|
||||
// Restores the state of this iterator.
|
||||
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
|
||||
if (reader->Contains(kIteratorExhausted)) {
|
||||
LOG(INFO) << "Iterator exhausted. Nothing to restore.";
|
||||
is_exhausted_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
return RestoreInternal(ctx, reader);
|
||||
}
|
||||
return RestoreInternal(ctx, reader);
|
||||
}
|
||||
|
||||
static const char kIteratorExhausted[];
|
||||
|
||||
protected:
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `SaveInternal` on their parent iterators, e.g., in
|
||||
@ -354,8 +341,6 @@ class IteratorBase {
|
||||
IteratorStateReader* reader) {
|
||||
return errors::Unimplemented("RestoreInternal");
|
||||
}
|
||||
|
||||
bool is_exhausted_ = false; // Whether the iterator has been exhausted.
|
||||
};
|
||||
|
||||
// Represents a (potentially infinite) range of outputs, where each
|
||||
@ -491,10 +476,6 @@ class DatasetIterator : public IteratorBase {
|
||||
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) final {
|
||||
port::Tracing::TraceMe activity(params_.prefix);
|
||||
if (is_exhausted_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
return GetNextInternal(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel {
|
||||
if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
|
||||
(dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
|
||||
*end_of_sequence = true;
|
||||
is_exhausted_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
Tensor value_tensor(cpu_allocator(), DT_INT64, {});
|
||||
|
@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
|
||||
// Iteration ends when there are no more files to process.
|
||||
if (current_file_index_ == dataset()->filenames_.size()) {
|
||||
*end_of_sequence = true;
|
||||
is_exhausted_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -117,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
while (i_ < dataset()->count_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||
@ -127,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
|
||||
input_impl_ = dataset()->input_->MakeIterator(prefix());
|
||||
}
|
||||
*end_of_sequence = true;
|
||||
is_exhausted_ = true;
|
||||
input_impl_.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
@ -136,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -144,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
mutex_lock l(mu_);
|
||||
int64 start_micros = ctx->env()->NowMicros();
|
||||
int64 num_log_entries = 0;
|
||||
while (!end_of_input_sequence_ &&
|
||||
buffer_.size() < dataset()->buffer_size_) {
|
||||
while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
|
||||
if (ctx->env()->NowMicros() >
|
||||
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
|
||||
num_log_entries++;
|
||||
@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
<< buffer_.size() << " of " << dataset()->buffer_size_;
|
||||
}
|
||||
std::vector<Tensor> input_element;
|
||||
bool end_of_input_sequence;
|
||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
|
||||
&end_of_input_sequence_));
|
||||
if (!end_of_input_sequence_) {
|
||||
&end_of_input_sequence));
|
||||
if (!end_of_input_sequence) {
|
||||
buffer_.emplace_back(std::move(input_element));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::swap(buffer_[index], buffer_.back());
|
||||
buffer_.pop_back();
|
||||
} else {
|
||||
DCHECK(end_of_input_sequence_);
|
||||
DCHECK(input_impl_ == nullptr);
|
||||
*end_of_sequence = true;
|
||||
}
|
||||
return Status::OK();
|
||||
@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Save the tensors in the buffer.
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
|
||||
for (int i = 0; i < buffer_.size(); i++) {
|
||||
for (size_t i = 0; i < buffer_.size(); i++) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat("buffer_", i, "_size")),
|
||||
buffer_[i].size()));
|
||||
for (int j = 0; j < buffer_[i].size(); j++) {
|
||||
for (size_t j = 0; j < buffer_[i].size(); j++) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat("buffer_", i, "_", j)),
|
||||
buffer_[i][j]));
|
||||
@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
// Save input iterator if it hasn't been exhausted else write
|
||||
// "end_of_input_sequence".
|
||||
if (end_of_input_sequence_) {
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("end_of_input_sequence"), ""));
|
||||
} else {
|
||||
@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
buffer_.clear();
|
||||
|
||||
// Restore the buffer.
|
||||
int64 buffer_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("buffer_size"), &buffer_size));
|
||||
for (int i = 0; i < buffer_size; i++) {
|
||||
size_t buffer_size;
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("buffer_size"), &temp));
|
||||
buffer_size = static_cast<size_t>(temp);
|
||||
}
|
||||
buffer_.reserve(buffer_size);
|
||||
for (size_t i = 0; i < buffer_size; i++) {
|
||||
int64 list_size;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
|
||||
@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
input_impl_ = dataset()->input_->MakeIterator(prefix());
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
end_of_input_sequence_ = true;
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
|
||||
mutex mu_;
|
||||
std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
|
||||
const int64 seed_ GUARDED_BY(mu_);
|
||||
const int64 seed2_ GUARDED_BY(mu_);
|
||||
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
|
||||
|
@ -118,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
|
||||
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Keep calling GetNext(). TODO(vrv): Figure out a way to
|
||||
// skip records without reading, perhaps by adding an
|
||||
// interface to iterator.
|
||||
@ -138,6 +143,9 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Return GetNext() on the underlying iterator.
|
||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
|
||||
end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -145,7 +153,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -153,7 +166,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -118,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
|
||||
if (!input_impl_) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
while (i_ < dataset()->count_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||
@ -136,7 +140,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -144,7 +153,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -109,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impls_.empty()) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
out_tensors->clear();
|
||||
out_tensors->reserve(dataset()->output_dtypes().size());
|
||||
for (const auto& input_impl : input_impls_) {
|
||||
@ -116,28 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
return Status::OK();
|
||||
break;
|
||||
}
|
||||
out_tensors->insert(out_tensors->end(), input_tensors.begin(),
|
||||
input_tensors.end());
|
||||
}
|
||||
*end_of_sequence = false;
|
||||
if (*end_of_sequence) {
|
||||
out_tensors->clear();
|
||||
input_impls_.clear();
|
||||
} else {
|
||||
*end_of_sequence = false;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
for (auto& input_impl : input_impls_)
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
|
||||
if (input_impls_.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impls_empty"), ""));
|
||||
} else {
|
||||
for (auto& input_impl : input_impls_)
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
for (auto& input_impl : input_impls_)
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
|
||||
if (reader->Contains(full_name("input_impls_empty"))) {
|
||||
input_impls_.clear();
|
||||
} else {
|
||||
DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
|
||||
for (auto& input_impl : input_impls_)
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user