Add tests for using tf.data service with distribution strategies.
PiperOrigin-RevId: 350215726 Change-Id: I8329af407456bf8a6a5268f2671a39bdd7e09578
This commit is contained in:
parent
8232f94fa3
commit
866d081cb4
@ -25,7 +25,9 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.experimental.service import server_lib
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import device_util
|
||||
@ -1644,5 +1646,61 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
|
||||
assert loop_num == len(expected) - 1
|
||||
|
||||
|
||||
class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    self.num_workers = 3
    if combinations.in_main_process():
      # Servers are created only in the main process; subprocesses reach them
      # via the dispatcher target recorded in the combinations environment.
      self.dispatcher = server_lib.DispatchServer()
      # Strip the protocol prefix (e.g. "grpc://") from the dispatcher target.
      dispatcher_address = self.dispatcher.target.split("://")[1]
      self.workers = [
          server_lib.WorkerServer(
              server_lib.WorkerConfig(
                  dispatcher_address=dispatcher_address,
                  heartbeat_interval_ms=100,
                  dispatcher_timeout_ms=1000))
          for _ in range(self.num_workers)
      ]
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    """Reads a tf.data-service-backed dataset through a distributed iterator.

    Args:
      distribution: the `tf.distribute.Strategy` under test, supplied by the
        combinations framework.
    """
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    # Route a simple range dataset through the tf.data service cluster that
    # setUp started. With "parallel_epochs" every service worker processes the
    # full dataset, so each element appears num_workers times overall.
    dataset = dataset_ops.Dataset.range(50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode="parallel_epochs",
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))

    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    results = []
    for batch in iter(dist_dataset):
      for per_replica_value in distribution.experimental_local_results(batch):
        results.append(per_replica_value.numpy())
    self.assertNotEmpty(results)

    # Gather across replicas/hosts, then check multiset equality: order is not
    # guaranteed, but every element 0..49 must occur exactly num_workers times.
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(50)), gathered)
|
||||
|
||||
|
||||
# Entry point. NOTE(review): `test_util.main()` is presumably the
# distribute-aware test main (rather than `tf.test.main`) so that
# multi-worker combinations can set up their subprocess environment
# before tests execute — confirm against the file's imports.
if __name__ == "__main__":
  test_util.main()
|
||||
|
Loading…
Reference in New Issue
Block a user