Fix flakiness in input_lib_test.
The test could sometimes fail due to per-replica padding. To fix the test we ignore the padding. PiperOrigin-RevId: 351876267 Change-Id: Iae4ca11e41beb9783c191ba026aabd0b7ea7bbab
This commit is contained in:
parent
ab789c1598
commit
94f9284bfb
@ -1681,7 +1681,7 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
|
|||||||
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.range(50)
|
dataset = dataset_ops.Dataset.range(1, 50)
|
||||||
dataset = dataset.apply(
|
dataset = dataset.apply(
|
||||||
data_service_ops._distribute(
|
data_service_ops._distribute(
|
||||||
processing_mode="parallel_epochs",
|
processing_mode="parallel_epochs",
|
||||||
@ -1696,10 +1696,13 @@ class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
|
|||||||
for element in iterator:
|
for element in iterator:
|
||||||
local_results = distribution.experimental_local_results(element)
|
local_results = distribution.experimental_local_results(element)
|
||||||
for result in local_results:
|
for result in local_results:
|
||||||
results.append(result.numpy())
|
# input_lib.distributed_dataset may add extra '0' elements to pad
|
||||||
|
# per-replica results.
|
||||||
|
if result.numpy() != 0:
|
||||||
|
results.append(result.numpy())
|
||||||
self.assertNotEmpty(results)
|
self.assertNotEmpty(results)
|
||||||
gathered = distribution.gather(constant_op.constant(results), axis=0)
|
gathered = distribution.gather(constant_op.constant(results), axis=0)
|
||||||
self.assertCountEqual(self.num_workers * list(range(50)), gathered)
|
self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user