Add tests for using tf.data service with distribution strategies.
PiperOrigin-RevId: 350215726 Change-Id: I8329af407456bf8a6a5268f2671a39bdd7e09578
This commit is contained in:
parent
8232f94fa3
commit
866d081cb4
@ -25,7 +25,9 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||||
|
from tensorflow.python.data.experimental.service import server_lib
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
@ -1644,5 +1646,61 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
|
|||||||
assert loop_num == len(expected) - 1
|
assert loop_num == len(expected) - 1
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
"""Tests for distributed iterators which read from tf.data service."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(DistributedIteratorTfDataServiceTest, self).setUp()
|
||||||
|
self.num_workers = 3
|
||||||
|
if combinations.in_main_process():
|
||||||
|
self.dispatcher = server_lib.DispatchServer()
|
||||||
|
self.workers = []
|
||||||
|
for _ in range(self.num_workers):
|
||||||
|
self.workers.append(
|
||||||
|
server_lib.WorkerServer(
|
||||||
|
server_lib.WorkerConfig(
|
||||||
|
dispatcher_address=self.dispatcher.target.split("://")[1],
|
||||||
|
heartbeat_interval_ms=100,
|
||||||
|
dispatcher_timeout_ms=1000)))
|
||||||
|
combinations.env().tf_data_service_dispatcher = self.dispatcher.target
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
mode=["eager"],
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.one_device_strategy,
|
||||||
|
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||||
|
]))
|
||||||
|
def testTfDataService(self, distribution):
|
||||||
|
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||||
|
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(50)
|
||||||
|
dataset = dataset.apply(
|
||||||
|
data_service_ops._distribute(
|
||||||
|
processing_mode="parallel_epochs",
|
||||||
|
service=combinations.env().tf_data_service_dispatcher,
|
||||||
|
job_name="foo"))
|
||||||
|
|
||||||
|
dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
|
||||||
|
distribution)
|
||||||
|
|
||||||
|
iterator = iter(dist_dataset)
|
||||||
|
results = []
|
||||||
|
for element in iterator:
|
||||||
|
local_results = distribution.experimental_local_results(element)
|
||||||
|
for result in local_results:
|
||||||
|
results.append(result.numpy())
|
||||||
|
self.assertNotEmpty(results)
|
||||||
|
gathered = distribution.gather(constant_op.constant(results), axis=0)
|
||||||
|
self.assertCountEqual(self.num_workers * list(range(50)), gathered)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_util.main()
|
test_util.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user