Fix issue where CacheDataset GetNext returns OutOfRange.
Calling GetNext on a completed dataset iterator should return Status::OK and set `end_of_sequence` to true. Previously, CacheDataset set `end_of_sequence` to true but still returned Status::OutOfRange. This broke ParallelInterleave, which fails when an input iterator returns a non-OK Status from GetNext.

PiperOrigin-RevId: 255535560
commit 3398e887f5 (parent e0d9dfd54b)
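For context, the sketch below shows roughly the consumer-side contract the fix restores. It is a minimal illustration, not code from this change: it assumes the TensorFlow-internal types IteratorContext, IteratorBase, Tensor, and Status shown in the diffs below, and the helper name DrainInput is hypothetical. A downstream iterator such as ParallelInterleave keeps calling GetNext until `end_of_sequence` becomes true and treats any non-OK status as a real error.

// Sketch only: how a downstream iterator is expected to consume its input.
// End of input is signaled by Status::OK() together with
// *end_of_sequence == true; any non-OK status is propagated as an error.
Status DrainInput(IteratorContext* ctx, IteratorBase* input) {  // hypothetical helper
  std::vector<Tensor> element;
  bool end_of_sequence = false;
  while (!end_of_sequence) {
    element.clear();
    TF_RETURN_IF_ERROR(input->GetNext(ctx, &element, &end_of_sequence));
    if (end_of_sequence) break;
    // ... consume `element` ...
  }
  return Status::OK();
}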
@@ -419,12 +419,12 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
   Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
   if (s.ok() && !*end_of_sequence) RecordElement(ctx);
   RecordStop(ctx, /*start_output=*/true);
-  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
-    s = errors::Internal(
-        "Iterator \"", params_.prefix,
-        "\" returned OutOfRange without setting `*end_of_sequence`. This "
-        "indicates that an error may have occurred. Original message: ",
-        s.error_message());
+  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
+    s = errors::Internal("Iterator \"", params_.prefix,
+                         "\" returned `OutOfRange`. This indicates an "
+                         "implementation error as `OutOfRange` errors are not "
+                         "expected to be returned here. Original message: ",
+                         s.error_message());
     LOG(ERROR) << s;
   }
   return s;
@@ -472,6 +472,12 @@ class IteratorBase {
   // be stored in `*end_of_sequence`, and the content of
   // `*out_tensors` will be undefined.
   //
+  // Implementations should never return `OutOfRange` error. If at end of
+  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
+  // Internally raised `OutOfRange` errors that do not imply end of sequence
+  // should be converted to a different error type before being propagated to
+  // the caller.
+  //
   // This method is thread-safe.
   //
   // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
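As a rough illustration of the contract documented above (a sketch, not code from this change; `ReadElement`, `next_index_`, and `num_elements_` are hypothetical names), an implementation reports end of sequence through the out-parameter and rewraps any internally raised OutOfRange:

// Sketch of an iterator implementation following the documented contract.
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                       bool* end_of_sequence) override {
  mutex_lock l(mu_);
  if (next_index_ >= num_elements_) {
    // End of sequence is not an error: report it via the out-parameter.
    *end_of_sequence = true;
    return Status::OK();
  }
  Status s = ReadElement(next_index_, out_tensors);  // hypothetical helper
  if (errors::IsOutOfRange(s)) {
    // An OutOfRange raised internally does not mean end of sequence here;
    // convert it so it is not mistaken for (or rejected as) an EOF signal.
    return errors::Internal("Unexpected OutOfRange: ", s.error_message());
  }
  TF_RETURN_IF_ERROR(s);
  *end_of_sequence = false;
  ++next_index_;
  return Status::OK();
}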
@@ -212,9 +212,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(
-            HandleEOF(EnsureLockFileExists(), end_of_sequence));
-        TF_RETURN_IF_ERROR(HandleEOF(writer_->status(), end_of_sequence));
+        *end_of_sequence = false;
+        TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
+        if (*end_of_sequence) {
+          return Status::OK();
+        }
+        TF_RETURN_IF_ERROR(writer_->status());
         if (cur_index_ >= kMaxItems) {
           // As a courtesy, close the [truncated] cache file.
           Status s = Finish();
@@ -331,11 +334,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
       }

      private:
-      Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (iteration_completed_)
-          return errors::OutOfRange(
-              "Attempting to call get_next after iteration should have "
-              "finished.");
+      Status EnsureLockFileExists(bool* end_of_sequence)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (iteration_completed_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
         if (lockfile_created_ && !iteration_completed_) return Status::OK();

         // Perform rudimentary locking to help catch concurrent writes to the
@@ -419,13 +423,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
         return Status::OK();
       }

-      Status HandleEOF(Status s, bool* end_of_sequence) {
-        if (errors::IsOutOfRange(s)) {
-          *end_of_sequence = true;
-        }
-        return s;
-      }
-
       mutex mu_;
       size_t cur_index_ GUARDED_BY(mu_);
       // Index of the current shard. This gets incremented whenever a new
@@ -50,7 +50,7 @@ class FilterDatasetOpTest : public DatasetOpsTestBase {
     return Status::OK();
   }

-  // Creates a new `ParallelInterleaveDataset` op kernel context.
+  // Creates a new `FilterDataset` op kernel context.
   Status CreateFilterDatasetContext(
       OpKernel *const op_kernel,
       gtl::InlinedVector<TensorValue, 4> *const inputs,
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import multiprocessing
+import os

 from absl.testing import parameterized
 import numpy as np
@@ -258,6 +259,17 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):

     self.assertDatasetProduces(dataset, [4 * x for x in range(100)])

+  def testParallelInterleaveCached(self):
+    dataset = dataset_ops.Dataset.range(5)
+    dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir"))
+
+    def interleave_fn(x):
+      return dataset_ops.Dataset.from_tensors(x)
+
+    dataset = dataset.interleave(
+        interleave_fn, cycle_length=2, num_parallel_calls=2)
+    self.assertDatasetProduces(dataset, list(range(5)))
+

 if __name__ == "__main__":
   test.main()