[tf.data service] Fix bug with ForeverRepeat starting off empty.

ForeverRepeat has special handling for the case where its input dataset produces end_of_sequence for the first GetNext request. When this happens, we exit ForeverRepeat early, assuming that it will never be able to produce data. However, this is not true when using a split provider, since it is possible that the current repetition's data has already been consumed by other consumers, but a later repetition will yield data.
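
For illustration, the failure mode can be reproduced from the Python API
roughly as follows (a hedged sketch: the dispatcher address is hypothetical,
and a dispatcher with five registered workers is assumed to already be
running):

import tensorflow as tf

# One element but five workers: on the first repetition, at least four
# workers receive no splits from the split provider, so their first
# GetNext() returns end_of_sequence.
dispatcher_address = "grpc://localhost:5000"  # hypothetical address
ds = tf.data.Dataset.range(1).repeat()
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode="distributed_epoch", service=dispatcher_address))

# Before this fix, a worker whose first GetNext() hit end_of_sequence
# treated the repeated dataset as permanently empty, even though a later
# repetition would have assigned it splits.
for elem in ds.take(10):
  print(elem.numpy())  # expected: 0, printed ten times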

PiperOrigin-RevId: 339571610
Change-Id: Ib94196fe4a7c8ec5d1a4b3e2422c20585707e933
Author: Andrew Audibert (committed by TensorFlower Gardener)
Date: 2020-10-28 17:31:21 -07:00
Parent: e91327a5b5
Commit: 1e98f53e0a
3 changed files with 26 additions and 1 deletion

tensorflow/core/kernels/data/repeat_dataset_op.cc

@@ -234,7 +234,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
        }
        Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
        DCHECK(!*end_of_sequence || out_tensors->empty());
-       if (first_call_ && *end_of_sequence) {
+       if (first_call_ && *end_of_sequence && !ctx->split_provider()) {
          // If the first call to GetNext() fails because the end
          // of sequence has been reached, we terminate the
          // iteration immediately. (Otherwise, this iterator
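
Paraphrased in Python, the guarded early exit behaves roughly like this (an
illustrative sketch only; make_input_iter, get_next, and ctx.split_provider
are stand-ins for the C++ interfaces, not real TensorFlow APIs):

def forever_repeat_get_next(make_input_iter, ctx):
  """Sketch of ForeverIterator's GetNext handling after this change."""
  first_call = True
  input_iter = make_input_iter()
  while True:
    out, end_of_sequence = input_iter.get_next(ctx)
    if not end_of_sequence:
      return out, False
    if first_call and ctx.split_provider is None:
      # No split provider: an empty first pass means the input dataset is
      # truly empty, so repeating it forever would never produce data.
      return None, True
    # With a split provider, this repetition may simply have been consumed
    # by other workers; reset the input and try the next repetition.
    first_call = False
    input_iter = make_input_iter()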

tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py

@@ -447,6 +447,26 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     for i in range(num_elements):
       self.assertGreater(results[i], elements_to_read / num_elements / 2)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDistributeDistributedEpochForeverRepeatFewElements(self):
+    num_workers = 5
+    cluster = self.create_cluster(num_workers=num_workers)
+    # Less than the number of workers, so that some workers get zero elements
+    # on the first repetition.
+    num_elements = 1
+    ds = dataset_ops.Dataset.range(num_elements).repeat()
+    ds = self.make_distributed_dataset(
+        ds, cluster, processing_mode="distributed_epoch")
+    it = iter(ds)
+    for _ in range(100):
+      self.assertEqual(next(it).numpy(), 0)
+
+    # Stop all but one worker and check that we can still read.
+    for i in range(num_workers - 1):
+      cluster.workers[i]._stop()
+    for _ in range(100):
+      self.assertEqual(next(it).numpy(), 0)
+
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochShuffleAndRepeat(self):
     cluster = self.create_cluster(num_workers=2)

tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py

@@ -152,6 +152,11 @@ class TestCluster(object):
   def num_tasks_on_worker(self, worker_index=0):
     return self.workers[worker_index]._num_tasks()
 
+  def __del__(self):
+    # Destroy workers before the dispatcher for clean shutdown.
+    self.workers.clear()
+    del self.dispatcher
+
 
 class TestBase(test_base.DatasetTestBase):
   """Base class for tf.data service tests."""