Internal change

PiperOrigin-RevId: 349608848
Change-Id: Idb0733edbdedc55ea588875f2f542c3f22af2954
This commit is contained in:
A. Unique TensorFlower 2020-12-30 15:30:04 -08:00 committed by TensorFlower Gardener
parent 0fa06203b9
commit 4ef33e3c38
2 changed files with 5 additions and 33 deletions

View File

@ -37,12 +37,6 @@ constexpr char kCurIteration[] = "i";
// NOTE(review): the two string constants below look like key names for saved
// iterator state; their use is not visible in this excerpt -- confirm.
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kUninitialized[] = "uninitialized";
// NOTE(review): presumably the ratio between input and output cardinality
// reported for this dataset; the use site is not visible here -- confirm.
constexpr int64 kKnownRatio = 1;
// Number of empty iterations before returning `end_of_sequence` from
// ForeverRepeat. We choose 10000 to be low enough that it takes very little
// time for datasets to detect infinite loops, but high enough to be reasonably
// confident that the input dataset will continue to produce empty sequences
// forever. Once the count of empty passes reaches this threshold, GetNext logs
// a warning and stops instead of restarting the input again.
constexpr int64 kMinEmptyForeverRepeatRetries = 10000;
class RepeatDatasetOp::Dataset : public DatasetBase {
public:
@ -233,7 +227,6 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
int64 empty_iterations = 0;
do {
if (!input_impl_) {
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
@ -241,17 +234,14 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
}
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
DCHECK(!*end_of_sequence || out_tensors->empty());
if (*end_of_sequence &&
empty_iterations >= kMinEmptyForeverRepeatRetries) {
LOG(WARNING) << "Exiting repeat() early to avoid infinite loop. "
"Upstream iterator was empty "
<< empty_iterations << " times in a row";
if (first_call_ && *end_of_sequence && !ctx->split_provider()) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
// would loop infinitely and never produce a value.)
input_impl_.reset();
return Status::OK();
}
if (first_call_ && *end_of_sequence) {
empty_iterations++;
}
first_call_ = false;
if (!*end_of_sequence) {
return s;

View File

@ -54,24 +54,6 @@ class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
# Repeating an empty range indefinitely (repeat() with no count) is expected
# to yield no elements at all.
@combinations.generate(test_base.default_test_combinations())
def testInfiniteEmptyRepeat(self):
# range(0) has no elements, so every pass over the input is empty.
dataset = dataset_ops.Dataset.range(0).repeat()
self.assertDatasetProduces(dataset, [])
# Builds an input that is empty on most passes and checks that an infinite
# repeat() over it still eventually produces elements.
@combinations.generate(test_base.default_test_combinations())
def testInfiniteProbablyEmptyRepeat(self):
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.shuffle(100)
dataset = dataset.take(1)
dataset = dataset.filter(lambda x: x < 5)
# At this point `dataset` has a 5% chance of being nonempty: one element is
# drawn from 100 shuffled values and kept only if it is one of 0..4.
# NOTE(review): relies on shuffle() producing a different order on each
# repetition -- confirm against the shuffle defaults.
dataset = dataset.repeat()
dataset = dataset.take(10)
# repeat() should try many times to get data from the dataset instead of
# giving up and returning an empty sequence.
self.assertNotEmpty(self.getDatasetOutput(dataset))
@combinations.generate(test_base.default_test_combinations())
def testRepeatRepeat(self):
"""Test the composition of repeat datasets."""