Fix issue where CacheDataset GetNext returns OutOfRange.

Calling GetNext on a completed dataset iterator should
return Status::OK and set the end_of_sequence bool to
true. Previously, CacheDataset's iterator was setting
end_of_sequence to true but still returning
Status::OutOfRange. This broke ParallelInterleave, which
fails when an input iterator returns a non-OK Status from
GetNext.
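For context, a minimal sketch of the intended contract (illustrative only, not code from this change; MyIterator, items_, and index_ are hypothetical):

// Illustrative only: signal end of sequence via the bool, with an OK status.
Status MyIterator::GetNextInternal(IteratorContext* ctx,
                                   std::vector<Tensor>* out_tensors,
                                   bool* end_of_sequence) {
  if (index_ >= items_.size()) {
    *end_of_sequence = true;  // Completed: report via the flag...
    return Status::OK();      // ...and return OK, never OutOfRange.
  }
  out_tensors->push_back(items_[index_++]);
  *end_of_sequence = false;
  return Status::OK();
}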

PiperOrigin-RevId: 255535560
Andrew Audibert 2019-06-27 21:17:07 -07:00 committed by TensorFlower Gardener
parent e0d9dfd54b
commit 3398e887f5
5 changed files with 37 additions and 22 deletions

@@ -419,12 +419,12 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
   Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
   if (s.ok() && !*end_of_sequence) RecordElement(ctx);
   RecordStop(ctx, /*start_output=*/true);
-  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
-    s = errors::Internal(
-        "Iterator \"", params_.prefix,
-        "\" returned OutOfRange without setting `*end_of_sequence`. This "
-        "indicates that an error may have occurred. Original message: ",
-        s.error_message());
+  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
+    s = errors::Internal("Iterator \"", params_.prefix,
+                         "\" returned `OutOfRange`. This indicates an "
+                         "implementation error as `OutOfRange` errors are not "
+                         "expected to be returned here. Original message: ",
+                         s.error_message());
     LOG(ERROR) << s;
   }
   return s;

@@ -472,6 +472,12 @@ class IteratorBase {
   // be stored in `*end_of_sequence`, and the content of
   // `*out_tensors` will be undefined.
   //
+  // Implementations should never return `OutOfRange` error. If at end of
+  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
+  // Internally raised `OutOfRange` errors that do not imply end of sequence
+  // should be converted to a different error type before being propagated to
+  // the caller.
+  //
   // This method is thread-safe.
   //
   // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and

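To illustrate the guidance added in the comment above, a hypothetical sketch of converting an internally raised OutOfRange before it reaches the caller (the reader_, ReadNext, and records_remaining_ names are assumptions, not part of this change):

// Illustrative only: an iterator backed by a reader that reports exhaustion
// as OutOfRange. The error is either turned into end-of-sequence or rewrapped.
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                       bool* end_of_sequence) override {
  Tensor value;
  Status s = reader_->ReadNext(&value);  // Hypothetical underlying reader.
  if (errors::IsOutOfRange(s)) {
    if (records_remaining_ > 0) {
      // An OutOfRange that does NOT mean end of sequence: convert it so
      // callers do not mistake a truncated input for normal completion.
      return errors::DataLoss("Truncated input: ", s.error_message());
    }
    *end_of_sequence = true;
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(s);
  --records_remaining_;
  out_tensors->push_back(std::move(value));
  *end_of_sequence = false;
  return Status::OK();
}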

@@ -212,9 +212,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(
-          HandleEOF(EnsureLockFileExists(), end_of_sequence));
-      TF_RETURN_IF_ERROR(HandleEOF(writer_->status(), end_of_sequence));
+      *end_of_sequence = false;
+      TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
+      if (*end_of_sequence) {
+        return Status::OK();
+      }
+      TF_RETURN_IF_ERROR(writer_->status());
       if (cur_index_ >= kMaxItems) {
         // As a courtesy, close the [truncated] cache file.
         Status s = Finish();
@@ -331,11 +334,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
     }
    private:
-    Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      if (iteration_completed_)
-        return errors::OutOfRange(
-            "Attempting to call get_next after iteration should have "
-            "finished.");
+    Status EnsureLockFileExists(bool* end_of_sequence)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      if (iteration_completed_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
       if (lockfile_created_ && !iteration_completed_) return Status::OK();
       // Perform rudimentary locking to help catch concurrent writes to the
@@ -419,13 +423,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
       return Status::OK();
     }
-    Status HandleEOF(Status s, bool* end_of_sequence) {
-      if (errors::IsOutOfRange(s)) {
-        *end_of_sequence = true;
-      }
-      return s;
-    }
     mutex mu_;
     size_t cur_index_ GUARDED_BY(mu_);
     // Index of the current shard. This gets incremented whenever a new

@@ -50,7 +50,7 @@ class FilterDatasetOpTest : public DatasetOpsTestBase {
     return Status::OK();
   }
-  // Creates a new `ParallelInterleaveDataset` op kernel context.
+  // Creates a new `FilterDataset` op kernel context.
   Status CreateFilterDatasetContext(
       OpKernel *const op_kernel,
       gtl::InlinedVector<TensorValue, 4> *const inputs,

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 import multiprocessing
+import os
 from absl.testing import parameterized
 import numpy as np
@@ -258,6 +259,17 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(dataset, [4 * x for x in range(100)])
+  def testParallelInterleaveCached(self):
+    dataset = dataset_ops.Dataset.range(5)
+    dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir"))
+
+    def interleave_fn(x):
+      return dataset_ops.Dataset.from_tensors(x)
+
+    dataset = dataset.interleave(
+        interleave_fn, cycle_length=2, num_parallel_calls=2)
+    self.assertDatasetProduces(dataset, list(range(5)))
 if __name__ == "__main__":
   test.main()