Get rid of IteratorBase::is_exhausted flag since it is not possible to rely on it unless we lock each call to GetNext which is not preferable.

Each iterator now handles saving/restoring exhausted state.
As a guideline, we always reset the input_impl(s) when they get exhausted. This can be used as an indicator of exhausted-ness for non-terminal iterators. Also reduces memory overhead.
Each iterator should also handle calls to GetNextInternal when it is exhausted. Fixed this for some datasets.
Also fix a bug in dataset_serialization_test_base. We were not saving
a checkpoint after exhausting the iterator so verify_exhausted_iterator
was not really testing restoring an exhausted iterator.

PiperOrigin-RevId: 175253023
This commit is contained in:
Saurabh Saxena 2017-11-09 21:01:00 -08:00 committed by TensorFlower Gardener
parent badd356488
commit 3c41cb6bff
12 changed files with 120 additions and 57 deletions

View File

@ -337,11 +337,11 @@ class DatasetSerializationTestBase(test.TestCase):
num_iters = end - start num_iters = end - start
for _ in range(num_iters): for _ in range(num_iters):
outputs.append(sess.run(get_next_op)) outputs.append(sess.run(get_next_op))
self._save(sess, saver)
ckpt_saved = True
if i == len(break_points) and verify_exhausted: if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op) sess.run(get_next_op)
self._save(sess, saver)
ckpt_saved = True
return outputs return outputs

View File

@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
// Each row of `batch_elements` is a tuple of tensors from the // Each row of `batch_elements` is a tuple of tensors from the
// input iterator. // input iterator.
std::vector<std::vector<Tensor>> batch_elements; std::vector<std::vector<Tensor>> batch_elements;
batch_elements.reserve(dataset()->batch_size_);
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
batch_elements.reserve(dataset()->batch_size_);
*end_of_sequence = false; *end_of_sequence = false;
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
++i) { ++i) {
@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
end_of_sequence)); end_of_sequence));
if (!*end_of_sequence) { if (!*end_of_sequence) {
batch_elements.emplace_back(std::move(batch_element_tuple)); batch_elements.emplace_back(std::move(batch_element_tuple));
} else {
input_impl_.reset();
} }
} }
} }
@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
protected: protected:
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); if (!input_impl_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
} else {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
}
return Status::OK(); return Status::OK();
} }
Status RestoreInternal(OpKernelContext* ctx, Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
return Status::OK(); return Status::OK();
} }

View File

@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (i_ < 2) { while (i_ < 2) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
} else if (i_ == 2) { } else if (i_ == 2) {
input_impl_.reset(); input_impl_.reset();
} }
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); if (input_impl_) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK(); return Status::OK();
} }

View File

@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
MakeDataset(ctx, input, another_input, output); MakeDataset(ctx, input, another_input, output);
} }
const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE"; "_DATASET_GRAPH_OUTPUT_NODE";

View File

@ -306,27 +306,14 @@ class IteratorBase {
// Saves the state of this iterator. // Saves the state of this iterator.
virtual Status Save(IteratorStateWriter* writer) { virtual Status Save(IteratorStateWriter* writer) {
if (is_exhausted_) { return SaveInternal(writer);
LOG(INFO) << "Iterator exhausted.";
return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
} else {
return SaveInternal(writer);
}
} }
// Restores the state of this iterator. // Restores the state of this iterator.
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
if (reader->Contains(kIteratorExhausted)) { return RestoreInternal(ctx, reader);
LOG(INFO) << "Iterator exhausted. Nothing to restore.";
is_exhausted_ = true;
return Status::OK();
} else {
return RestoreInternal(ctx, reader);
}
} }
static const char kIteratorExhausted[];
protected: protected:
// This is needed so that sub-classes of IteratorBase can call // This is needed so that sub-classes of IteratorBase can call
// `SaveInternal` on their parent iterators, e.g., in // `SaveInternal` on their parent iterators, e.g., in
@ -354,8 +341,6 @@ class IteratorBase {
IteratorStateReader* reader) { IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal"); return errors::Unimplemented("RestoreInternal");
} }
bool is_exhausted_ = false; // Whether the iterator has been exhausted.
}; };
// Represents a (potentially infinite) range of outputs, where each // Represents a (potentially infinite) range of outputs, where each
@ -491,10 +476,6 @@ class DatasetIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final { bool* end_of_sequence) final {
port::Tracing::TraceMe activity(params_.prefix); port::Tracing::TraceMe activity(params_.prefix);
if (is_exhausted_) {
*end_of_sequence = true;
return Status::OK();
}
return GetNextInternal(ctx, out_tensors, end_of_sequence); return GetNextInternal(ctx, out_tensors, end_of_sequence);
} }

View File

@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel {
if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) || if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
(dataset()->step_ < 0 && next_ <= dataset()->stop_)) { (dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
*end_of_sequence = true; *end_of_sequence = true;
is_exhausted_ = true;
return Status::OK(); return Status::OK();
} }
Tensor value_tensor(cpu_allocator(), DT_INT64, {}); Tensor value_tensor(cpu_allocator(), DT_INT64, {});

View File

@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
// Iteration ends when there are no more files to process. // Iteration ends when there are no more files to process.
if (current_file_index_ == dataset()->filenames_.size()) { if (current_file_index_ == dataset()->filenames_.size()) {
*end_of_sequence = true; *end_of_sequence = true;
is_exhausted_ = true;
return Status::OK(); return Status::OK();
} }

View File

@ -117,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (i_ < dataset()->count_) { while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@ -127,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix()); input_impl_ = dataset()->input_->MakeIterator(prefix());
} }
*end_of_sequence = true; *end_of_sequence = true;
is_exhausted_ = true;
input_impl_.reset(); input_impl_.reset();
return Status::OK(); return Status::OK();
} }
@ -136,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); if (!input_impl_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
} else {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
}
return Status::OK(); return Status::OK();
} }
@ -144,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
return Status::OK(); return Status::OK();
} }

View File

@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_); mutex_lock l(mu_);
int64 start_micros = ctx->env()->NowMicros(); int64 start_micros = ctx->env()->NowMicros();
int64 num_log_entries = 0; int64 num_log_entries = 0;
while (!end_of_input_sequence_ && while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
buffer_.size() < dataset()->buffer_size_) {
if (ctx->env()->NowMicros() > if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) { ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++; num_log_entries++;
@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
<< buffer_.size() << " of " << dataset()->buffer_size_; << buffer_.size() << " of " << dataset()->buffer_size_;
} }
std::vector<Tensor> input_element; std::vector<Tensor> input_element;
bool end_of_input_sequence;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element, TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
&end_of_input_sequence_)); &end_of_input_sequence));
if (!end_of_input_sequence_) { if (!end_of_input_sequence) {
buffer_.emplace_back(std::move(input_element)); buffer_.emplace_back(std::move(input_element));
} else { } else {
input_impl_.reset(); input_impl_.reset();
@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
std::swap(buffer_[index], buffer_.back()); std::swap(buffer_[index], buffer_.back());
buffer_.pop_back(); buffer_.pop_back();
} else { } else {
DCHECK(end_of_input_sequence_); DCHECK(input_impl_ == nullptr);
*end_of_sequence = true; *end_of_sequence = true;
} }
return Status::OK(); return Status::OK();
@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save the tensors in the buffer. // Save the tensors in the buffer.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("buffer_size"), buffer_.size())); writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
for (int i = 0; i < buffer_.size(); i++) { for (size_t i = 0; i < buffer_.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteScalar( TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("buffer_", i, "_size")), full_name(strings::StrCat("buffer_", i, "_size")),
buffer_[i].size())); buffer_[i].size()));
for (int j = 0; j < buffer_[i].size(); j++) { for (size_t j = 0; j < buffer_[i].size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor( TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("buffer_", i, "_", j)), full_name(strings::StrCat("buffer_", i, "_", j)),
buffer_[i][j])); buffer_[i][j]));
@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save input iterator if it hasn't been exhausted else write // Save input iterator if it hasn't been exhausted else write
// "end_of_input_sequence". // "end_of_input_sequence".
if (end_of_input_sequence_) { if (!input_impl_) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input_sequence"), "")); writer->WriteScalar(full_name("end_of_input_sequence"), ""));
} else { } else {
@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
buffer_.clear(); buffer_.clear();
// Restore the buffer. // Restore the buffer.
int64 buffer_size; size_t buffer_size;
TF_RETURN_IF_ERROR( {
reader->ReadScalar(full_name("buffer_size"), &buffer_size)); int64 temp;
for (int i = 0; i < buffer_size; i++) { TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("buffer_size"), &temp));
buffer_size = static_cast<size_t>(temp);
}
buffer_.reserve(buffer_size);
for (size_t i = 0; i < buffer_size; i++) {
int64 list_size; int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar( TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("buffer_", i, "_size")), &list_size)); full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix()); input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else { } else {
end_of_input_sequence_ = true;
input_impl_.reset(); input_impl_.reset();
} }
return Status::OK(); return Status::OK();
@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex mu_; mutex mu_;
std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_); std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
const int64 seed_ GUARDED_BY(mu_); const int64 seed_ GUARDED_BY(mu_);
const int64 seed2_ GUARDED_BY(mu_); const int64 seed2_ GUARDED_BY(mu_);
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);

View File

@ -118,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
// Keep calling GetNext(). TODO(vrv): Figure out a way to // Keep calling GetNext(). TODO(vrv): Figure out a way to
// skip records without reading, perhaps by adding an // skip records without reading, perhaps by adding an
// interface to iterator. // interface to iterator.
@ -138,6 +143,9 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
// Return GetNext() on the underlying iterator. // Return GetNext() on the underlying iterator.
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
end_of_sequence)); end_of_sequence));
if (*end_of_sequence) {
input_impl_.reset();
}
return Status::OK(); return Status::OK();
} }
@ -145,7 +153,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); if (input_impl_) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
}
return Status::OK(); return Status::OK();
} }
@ -153,7 +166,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
return Status::OK(); return Status::OK();
} }

View File

@ -118,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (i_ < dataset()->count_) { while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@ -136,7 +140,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); if (input_impl_) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
}
return Status::OK(); return Status::OK();
} }
@ -144,7 +153,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
return Status::OK(); return Status::OK();
} }

View File

@ -109,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel {
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); mutex_lock l(mu_);
if (input_impls_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
out_tensors->clear(); out_tensors->clear();
out_tensors->reserve(dataset()->output_dtypes().size()); out_tensors->reserve(dataset()->output_dtypes().size());
for (const auto& input_impl : input_impls_) { for (const auto& input_impl : input_impls_) {
@ -116,28 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
input_impl->GetNext(ctx, &input_tensors, end_of_sequence)); input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
if (*end_of_sequence) { if (*end_of_sequence) {
return Status::OK(); break;
} }
out_tensors->insert(out_tensors->end(), input_tensors.begin(), out_tensors->insert(out_tensors->end(), input_tensors.begin(),
input_tensors.end()); input_tensors.end());
} }
*end_of_sequence = false; if (*end_of_sequence) {
out_tensors->clear();
input_impls_.clear();
} else {
*end_of_sequence = false;
}
return Status::OK(); return Status::OK();
} }
protected: protected:
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_); mutex_lock l(mu_);
for (auto& input_impl : input_impls_) if (input_impls_.empty()) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl)); TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
} else {
for (auto& input_impl : input_impls_)
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
}
return Status::OK(); return Status::OK();
} }
Status RestoreInternal(OpKernelContext* ctx, Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_); mutex_lock l(mu_);
for (auto& input_impl : input_impls_) if (reader->Contains(full_name("input_impls_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl)); input_impls_.clear();
} else {
DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
for (auto& input_impl : input_impls_)
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
}
return Status::OK(); return Status::OK();
} }