diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index f8eeafc49da..4ed6bc63dca 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -234,7 +234,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { } Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); DCHECK(!*end_of_sequence || out_tensors->empty()); - if (first_call_ && *end_of_sequence) { + if (first_call_ && *end_of_sequence && !ctx->split_provider()) { // If the first call to GetNext() fails because the end // of sequence has been reached, we terminate the // iteration immediately. (Otherwise, this iterator diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index c00cd9d8302..97c906e4788 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -447,6 +447,26 @@ class DataServiceOpsTest(data_service_test_base.TestBase, for i in range(num_elements): self.assertGreater(results[i], elements_to_read / num_elements / 2) + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeDistributedEpochForeverRepeatFewElements(self): + num_workers = 5 + cluster = self.create_cluster(num_workers=num_workers) + # Less than the number of workers, so that some workers get zero elements on + # the first repetition. + num_elements = 1 + ds = dataset_ops.Dataset.range(num_elements).repeat() + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") + it = iter(ds) + for _ in range(100): + self.assertEqual(next(it).numpy(), 0) + + # Stop all but one worker and check that we can still read. + for i in range(num_workers - 1): + cluster.workers[i]._stop() + for _ in range(100): + self.assertEqual(next(it).numpy(), 0) + @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochShuffleAndRepeat(self): cluster = self.create_cluster(num_workers=2) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py index 0e48e1f4dd9..fe805850ec3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py @@ -152,6 +152,11 @@ class TestCluster(object): def num_tasks_on_worker(self, worker_index=0): return self.workers[worker_index]._num_tasks() + def __del__(self): + # Destroy workers before the dispatcher for clean shutdown. + self.workers.clear() + del self.dispatcher + class TestBase(test_base.DatasetTestBase): """Base class for tf.data service tests."""