[tf.data service] Fix bug with ForeverRepeat starting off empty.
ForeverRepeat has special handling for the case where its input dataset produces end_of_sequence for the first GetNext request. When this happens, we exit ForeverRepeat early, assuming that it will never be able to produce data. However, this is not true when using a split provider, since it is possible that the current repetition's data has already been consumed by other consumers, but a later repetition will yield data. PiperOrigin-RevId: 339571610 Change-Id: Ib94196fe4a7c8ec5d1a4b3e2422c20585707e933
This commit is contained in:
parent
e91327a5b5
commit
1e98f53e0a
@ -234,7 +234,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
DCHECK(!*end_of_sequence || out_tensors->empty());
|
||||
if (first_call_ && *end_of_sequence) {
|
||||
if (first_call_ && *end_of_sequence && !ctx->split_provider()) {
|
||||
// If the first call to GetNext() fails because the end
|
||||
// of sequence has been reached, we terminate the
|
||||
// iteration immediately. (Otherwise, this iterator
|
||||
|
@ -447,6 +447,26 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
for i in range(num_elements):
|
||||
self.assertGreater(results[i], elements_to_read / num_elements / 2)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochForeverRepeatFewElements(self):
|
||||
num_workers = 5
|
||||
cluster = self.create_cluster(num_workers=num_workers)
|
||||
# Less than the number of workers, so that some workers get zero elements on
|
||||
# the first repetition.
|
||||
num_elements = 1
|
||||
ds = dataset_ops.Dataset.range(num_elements).repeat()
|
||||
ds = self.make_distributed_dataset(
|
||||
ds, cluster, processing_mode="distributed_epoch")
|
||||
it = iter(ds)
|
||||
for _ in range(100):
|
||||
self.assertEqual(next(it).numpy(), 0)
|
||||
|
||||
# Stop all but one worker and check that we can still read.
|
||||
for i in range(num_workers - 1):
|
||||
cluster.workers[i]._stop()
|
||||
for _ in range(100):
|
||||
self.assertEqual(next(it).numpy(), 0)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochShuffleAndRepeat(self):
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
|
@ -152,6 +152,11 @@ class TestCluster(object):
|
||||
def num_tasks_on_worker(self, worker_index=0):
|
||||
return self.workers[worker_index]._num_tasks()
|
||||
|
||||
def __del__(self):
|
||||
# Destroy workers before the dispatcher for clean shutdown.
|
||||
self.workers.clear()
|
||||
del self.dispatcher
|
||||
|
||||
|
||||
class TestBase(test_base.DatasetTestBase):
|
||||
"""Base class for tf.data service tests."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user