diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 8a85f96d4b1..9d9a34e178b 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -25,7 +25,9 @@ import numpy as np from tensorflow.python import tf2 from tensorflow.python.compat import compat +from tensorflow.python.data.experimental.ops import data_service_ops from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy +from tensorflow.python.data.experimental.service import server_lib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import device_util @@ -1644,5 +1646,61 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, assert loop_num == len(expected) - 1 +class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase, + parameterized.TestCase): + """Tests for distributed iterators which read from tf.data service.""" + + def setUp(self): + super(DistributedIteratorTfDataServiceTest, self).setUp() + self.num_workers = 3 + if combinations.in_main_process(): + self.dispatcher = server_lib.DispatchServer() + self.workers = [] + for _ in range(self.num_workers): + self.workers.append( + server_lib.WorkerServer( + server_lib.WorkerConfig( + dispatcher_address=self.dispatcher.target.split("://")[1], + heartbeat_interval_ms=100, + dispatcher_timeout_ms=1000))) + combinations.env().tf_data_service_dispatcher = self.dispatcher.target + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + ])) + def testTfDataService(self, distribution): + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + input_workers = input_lib.InputWorkers(worker_device_pairs) + + dataset = dataset_ops.Dataset.range(50) + dataset = dataset.apply( + data_service_ops._distribute( + processing_mode="parallel_epochs", + service=combinations.env().tf_data_service_dispatcher, + job_name="foo")) + + dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, + distribution) + + iterator = iter(dist_dataset) + results = [] + for element in iterator: + local_results = distribution.experimental_local_results(element) + for result in local_results: + results.append(result.numpy()) + self.assertNotEmpty(results) + gathered = distribution.gather(constant_op.constant(results), axis=0) + self.assertCountEqual(self.num_workers * list(range(50)), gathered) + + if __name__ == "__main__": test_util.main()