Fix flakiness in input_lib_test.
The test could sometimes fail due to per-replica padding. To fix the test we ignore the padding. PiperOrigin-RevId: 351876267 Change-Id: Iae4ca11e41beb9783c191ba026aabd0b7ea7bbab
This commit is contained in:
parent
ab789c1598
commit
94f9284bfb
@ -1681,7 +1681,7 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(50)
|
||||
dataset = dataset_ops.Dataset.range(1, 50)
|
||||
dataset = dataset.apply(
|
||||
data_service_ops._distribute(
|
||||
processing_mode="parallel_epochs",
|
||||
@ -1696,10 +1696,13 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
|
||||
for element in iterator:
|
||||
local_results = distribution.experimental_local_results(element)
|
||||
for result in local_results:
|
||||
results.append(result.numpy())
|
||||
# input_lib.distributed_dataset may add extra '0' elements to pad
|
||||
# per-replica results.
|
||||
if result.numpy() != 0:
|
||||
results.append(result.numpy())
|
||||
self.assertNotEmpty(results)
|
||||
gathered = distribution.gather(constant_op.constant(results), axis=0)
|
||||
self.assertCountEqual(self.num_workers * list(range(50)), gathered)
|
||||
self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user