Fix flakiness in input_lib_test.

The test could sometimes fail because input_lib.distributed_dataset may pad per-replica results with extra '0' elements. To fix the test, the dataset now starts at 1, so a real element can never be 0, and any zero results are ignored as padding.
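For illustration, a minimal standalone sketch of the same "skip padded zeros" pattern, written against the public tf.distribute API rather than the internal tf.data service setup the test uses; the two-logical-CPU MirroredStrategy, the batch size, and the final checks are assumptions made for this sketch only:

# Hypothetical sketch (not the real test): collect per-replica results and
# ignore padded zeros. MirroredStrategy over two logical CPUs stands in for
# the tf.data service + input_lib setup used by input_lib_test.
import tensorflow as tf

# Split the single physical CPU into two logical devices so the strategy
# has two replicas to feed.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])
strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

# 49 elements starting at 1: a real element can never be 0, so any 0 seen
# below must be padding introduced for an uneven final step.
dataset = tf.data.Dataset.range(1, 50).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

results = []
for element in dist_dataset:
  for per_replica in strategy.experimental_local_results(element):
    for value in per_replica.numpy():
      if value != 0:  # ignore per-replica padding, as in the fix
        results.append(int(value))

# Every real element is collected exactly once and no padding leaks through.
assert 0 not in results
assert sorted(results) == list(range(1, 50))

Because the values start at 1, dropping zeros can only remove padding, regardless of whether the input pipeline actually pads the last uneven step.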

PiperOrigin-RevId: 351876267
Change-Id: Iae4ca11e41beb9783c191ba026aabd0b7ea7bbab
Author: Andrew Audibert
Date: 2021-01-14 14:22:26 -08:00
Committed by: TensorFlower Gardener
Parent: ab789c1598
Commit: 94f9284bfb


@@ -1681,7 +1681,7 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
     input_workers = input_lib.InputWorkers(worker_device_pairs)
-    dataset = dataset_ops.Dataset.range(50)
+    dataset = dataset_ops.Dataset.range(1, 50)
     dataset = dataset.apply(
         data_service_ops._distribute(
             processing_mode="parallel_epochs",
@@ -1696,10 +1696,13 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
     for element in iterator:
       local_results = distribution.experimental_local_results(element)
       for result in local_results:
-        results.append(result.numpy())
+        # input_lib.distributed_dataset may add extra '0' elements to pad
+        # per-replica results.
+        if result.numpy() != 0:
+          results.append(result.numpy())
     self.assertNotEmpty(results)
     gathered = distribution.gather(constant_op.constant(results), axis=0)
-    self.assertCountEqual(self.num_workers * list(range(50)), gathered)
+    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)


 if __name__ == "__main__":