Fix issue where CacheDataset GetNext returns OutOfRange instead of Status::OK at end of sequence.
Calling GetNext on a completed dataset iterator should return Status::OK and set the end_of_sequence bool to true. Previously, we were setting end_of_sequence to true but still returning Status::OutOfRange. This caused breakage in ParallelInterleave, which fails when an input iterator returns a non-OK Status from GetNext.

PiperOrigin-RevId: 255535560
This commit is contained in:
parent
e0d9dfd54b
commit
3398e887f5
@ -419,12 +419,12 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
|
|||||||
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
|
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
|
||||||
if (s.ok() && !*end_of_sequence) RecordElement(ctx);
|
if (s.ok() && !*end_of_sequence) RecordElement(ctx);
|
||||||
RecordStop(ctx, /*start_output=*/true);
|
RecordStop(ctx, /*start_output=*/true);
|
||||||
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
|
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
|
||||||
s = errors::Internal(
|
s = errors::Internal("Iterator \"", params_.prefix,
|
||||||
"Iterator \"", params_.prefix,
|
"\" returned `OutOfRange`. This indicates an "
|
||||||
"\" returned OutOfRange without setting `*end_of_sequence`. This "
|
"implementation error as `OutOfRange` errors are not "
|
||||||
"indicates that an error may have occurred. Original message: ",
|
"expected to be returned here. Original message: ",
|
||||||
s.error_message());
|
s.error_message());
|
||||||
LOG(ERROR) << s;
|
LOG(ERROR) << s;
|
||||||
}
|
}
|
||||||
return s;
|
return s;
|
||||||
|
@ -472,6 +472,12 @@ class IteratorBase {
|
|||||||
// be stored in `*end_of_sequence`, and the content of
|
// be stored in `*end_of_sequence`, and the content of
|
||||||
// `*out_tensors` will be undefined.
|
// `*out_tensors` will be undefined.
|
||||||
//
|
//
|
||||||
|
// Implementations should never return `OutOfRange` error. If at end of
|
||||||
|
// sequence, set `*end_of_sequence = true` and return `Status::OK()`.
|
||||||
|
// Internally raised `OutOfRange` errors that do not imply end of sequence
|
||||||
|
// should be converted to a different error type before being propagated to
|
||||||
|
// the caller.
|
||||||
|
//
|
||||||
// This method is thread-safe.
|
// This method is thread-safe.
|
||||||
//
|
//
|
||||||
// TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
|
// TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
|
||||||
|
@ -212,9 +212,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) override {
|
bool* end_of_sequence) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(
|
*end_of_sequence = false;
|
||||||
HandleEOF(EnsureLockFileExists(), end_of_sequence));
|
TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
|
||||||
TF_RETURN_IF_ERROR(HandleEOF(writer_->status(), end_of_sequence));
|
if (*end_of_sequence) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(writer_->status());
|
||||||
if (cur_index_ >= kMaxItems) {
|
if (cur_index_ >= kMaxItems) {
|
||||||
// As a courtesy, close the [truncated] cache file.
|
// As a courtesy, close the [truncated] cache file.
|
||||||
Status s = Finish();
|
Status s = Finish();
|
||||||
@ -331,11 +334,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
Status EnsureLockFileExists(bool* end_of_sequence)
|
||||||
if (iteration_completed_)
|
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
return errors::OutOfRange(
|
if (iteration_completed_) {
|
||||||
"Attempting to call get_next after iteration should have "
|
*end_of_sequence = true;
|
||||||
"finished.");
|
return Status::OK();
|
||||||
|
}
|
||||||
if (lockfile_created_ && !iteration_completed_) return Status::OK();
|
if (lockfile_created_ && !iteration_completed_) return Status::OK();
|
||||||
|
|
||||||
// Perform rudimentary locking to help catch concurrent writes to the
|
// Perform rudimentary locking to help catch concurrent writes to the
|
||||||
@ -419,13 +423,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HandleEOF(Status s, bool* end_of_sequence) {
|
|
||||||
if (errors::IsOutOfRange(s)) {
|
|
||||||
*end_of_sequence = true;
|
|
||||||
}
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
size_t cur_index_ GUARDED_BY(mu_);
|
size_t cur_index_ GUARDED_BY(mu_);
|
||||||
// Index of the current shard. This gets incremented whenever a new
|
// Index of the current shard. This gets incremented whenever a new
|
||||||
|
@ -50,7 +50,7 @@ class FilterDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new `ParallelInterleaveDataset` op kernel context.
|
// Creates a new `FilterDataset` op kernel context.
|
||||||
Status CreateFilterDatasetContext(
|
Status CreateFilterDatasetContext(
|
||||||
OpKernel *const op_kernel,
|
OpKernel *const op_kernel,
|
||||||
gtl::InlinedVector<TensorValue, 4> *const inputs,
|
gtl::InlinedVector<TensorValue, 4> *const inputs,
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -258,6 +259,17 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertDatasetProduces(dataset, [4 * x for x in range(100)])
|
self.assertDatasetProduces(dataset, [4 * x for x in range(100)])
|
||||||
|
|
||||||
|
def testParallelInterleaveCached(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(5)
|
||||||
|
dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir"))
|
||||||
|
|
||||||
|
def interleave_fn(x):
|
||||||
|
return dataset_ops.Dataset.from_tensors(x)
|
||||||
|
|
||||||
|
dataset = dataset.interleave(
|
||||||
|
interleave_fn, cycle_length=2, num_parallel_calls=2)
|
||||||
|
self.assertDatasetProduces(dataset, list(range(5)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user