Add tests for using tf.data service with distribution strategies.

PiperOrigin-RevId: 350215726
Change-Id: I8329af407456bf8a6a5268f2671a39bdd7e09578
This commit is contained in:
Andrew Audibert 2021-01-05 14:06:49 -08:00 committed by TensorFlower Gardener
parent 8232f94fa3
commit 866d081cb4

View File

@ -25,7 +25,9 @@ import numpy as np
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
@ -1644,5 +1646,61 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
assert loop_num == len(expected) - 1
class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    """Starts an in-process tf.data service cluster: 1 dispatcher, 3 workers.

    The servers are only started in the main process; subprocesses spawned
    for multi-worker strategies find the dispatcher through the shared
    `combinations.env()` object instead of starting their own cluster.
    """
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    # Number of tf.data service worker servers (not distribution-strategy
    # workers). Also used below to compute the expected element count.
    self.num_workers = 3
    if combinations.in_main_process():
      self.dispatcher = server_lib.DispatchServer()
      self.workers = []
      for _ in range(self.num_workers):
        self.workers.append(
            server_lib.WorkerServer(
                server_lib.WorkerConfig(
                    # dispatcher.target has the form "<protocol>://host:port";
                    # WorkerConfig expects the bare "host:port" address.
                    dispatcher_address=self.dispatcher.target.split("://")[1],
                    heartbeat_interval_ms=100,
                    dispatcher_timeout_ms=1000)))
      # Publish the dispatcher address so test subprocesses (e.g. for
      # multi_worker_mirrored strategies) can reach the same cluster.
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    """Reads a service-distributed dataset through a distributed iterator.

    Args:
      distribution: the `tf.distribute.Strategy` under test (supplied by the
        combinations framework).
    """
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    # Route a 50-element range dataset through the tf.data service cluster
    # started in setUp. With "parallel_epochs" processing mode, each service
    # worker processes the full dataset (see the count assertion below).
    dataset = dataset_ops.Dataset.range(50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode="parallel_epochs",
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))

    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      # Each distributed element holds one per-replica value per local device.
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        results.append(result.numpy())
    self.assertNotEmpty(results)
    # Gather across replicas so multi-worker strategies check the global
    # element set, not just this process's shard.
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    # Every tf.data service worker emits the full range once, so we expect
    # num_workers copies of [0, 50), in any order.
    self.assertCountEqual(self.num_workers * list(range(50)), gathered)
if __name__ == "__main__":
  # Use the distribute test harness entry point (handles multi-process
  # strategy combinations) rather than plain test.main().
  test_util.main()