Internal change
PiperOrigin-RevId: 349608848 Change-Id: Idb0733edbdedc55ea588875f2f542c3f22af2954
This commit is contained in:
parent
0fa06203b9
commit
4ef33e3c38
@ -37,12 +37,6 @@ constexpr char kCurIteration[] = "i";
|
||||
constexpr char kInputImplEmpty[] = "input_impl_empty";
|
||||
constexpr char kUninitialized[] = "uninitialized";
|
||||
constexpr int64 kKnownRatio = 1;
|
||||
// Number of empty iterations before returning `end_of_sequence` from
|
||||
// ForeverRepeat. We choose 10000 to be low enough that it takes very little
|
||||
// time for datasets to detect infinite loops, but high enough to be reasonably
|
||||
// confident that the input dataset will continue to produce empty sequence
|
||||
// forever.
|
||||
constexpr int64 kMinEmptyForeverRepeatRetries = 10000;
|
||||
|
||||
class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
@ -233,7 +227,6 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
|
||||
int64 empty_iterations = 0;
|
||||
do {
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
@ -241,17 +234,14 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
DCHECK(!*end_of_sequence || out_tensors->empty());
|
||||
if (*end_of_sequence &&
|
||||
empty_iterations >= kMinEmptyForeverRepeatRetries) {
|
||||
LOG(WARNING) << "Exiting repeat() early to avoid infinite loop. "
|
||||
"Upstream iterator was empty "
|
||||
<< empty_iterations << " times in a row";
|
||||
if (first_call_ && *end_of_sequence && !ctx->split_provider()) {
|
||||
// If the first call to GetNext() fails because the end
|
||||
// of sequence has been reached, we terminate the
|
||||
// iteration immediately. (Otherwise, this iterator
|
||||
// would loop infinitely and never produce a value.)
|
||||
input_impl_.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
if (first_call_ && *end_of_sequence) {
|
||||
empty_iterations++;
|
||||
}
|
||||
first_call_ = false;
|
||||
if (!*end_of_sequence) {
|
||||
return s;
|
||||
|
@ -54,24 +54,6 @@ class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInfiniteEmptyRepeat(self):
|
||||
dataset = dataset_ops.Dataset.range(0).repeat()
|
||||
self.assertDatasetProduces(dataset, [])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInfiniteProbablyEmptyRepeat(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.shuffle(100)
|
||||
dataset = dataset.take(1)
|
||||
dataset = dataset.filter(lambda x: x < 5)
|
||||
# At this point `dataset` has a 5% chance of being nonempty.
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.take(10)
|
||||
# repeat() should try many times to get data from the dataset instead of
|
||||
# giving up and returning empty sequence.
|
||||
self.assertNotEmpty(self.getDatasetOutput(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRepeatRepeat(self):
|
||||
"""Test the composition of repeat datasets."""
|
||||
|
Loading…
Reference in New Issue
Block a user