[tf.data service] Extend round robin unit test to cover multiple replicas per host.
PiperOrigin-RevId: 351900031
Change-Id: I2d1eb18cc24da64581f324f66ff244a277cc37ce
commit 9aedc576ae
parent 08d6a26260
@@ -334,14 +334,18 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     # Round robin reads can cause slow cluster shutdown.
     GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
-    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
-    ds = ds.shuffle(num_elements)
     low_bucket_max = 30
     mid_bucket_max = 60
     bucket_boundaries = [low_bucket_max, mid_bucket_max]
     batch_size = 10
-    num_consumers = 3
+    num_consumer_hosts = 3
+    replicas_per_consumer_host = 5
+    num_consumers = num_consumer_hosts * replicas_per_consumer_host
     bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
+    # Set up the dataset that will run on the tf.data workers.
+    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
+    ds = ds.shuffle(num_elements)
+    ds = ds.repeat()
     ds = ds.apply(
         grouping.bucket_by_sequence_length(
             lambda x: x,
@@ -354,28 +358,43 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
             lambda _, x: dataset_ops.Dataset.from_tensors(x),
             window_size=num_consumers))
     ds = ds.flat_map(lambda x: x)
-    ds = ds.repeat()
 
-    consumers = []
-    for consumer_index in range(num_consumers):
-      consumers.append(
-          self.make_distributed_dataset(
-              ds,
-              cluster,
-              job_name="test",
-              consumer_index=consumer_index,
-              num_consumers=num_consumers))
-    # Use parallel interleave to read from consumers in parallel.
-    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
+    # Set up the per-consumer-host datasets. During each global step, we pull
+    # `replicas_per_consumer_host` batches from each of these datasets.
+    host_datasets = []
+    for host_index in range(num_consumer_hosts):
+      per_replica_datasets = []
+      for i in range(replicas_per_consumer_host):
+        consumer_index = host_index * replicas_per_consumer_host + i
+        per_replica_datasets.append(
+            self.make_distributed_dataset(
+                ds,
+                cluster,
+                job_name="test",
+                consumer_index=consumer_index,
+                num_consumers=num_consumers))
+      host_dataset = dataset_ops.Dataset.from_tensor_slices(
+          per_replica_datasets)
+      host_dataset = host_dataset.interleave(
+          lambda x: x,
+          cycle_length=len(per_replica_datasets),
+          num_parallel_calls=len(per_replica_datasets),
+          deterministic=True)
+      host_datasets.append(host_dataset)
+
+    # Use parallel interleave to read from host datasets in parallel.
+    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
     ds = ds.interleave(
-        lambda x: x.prefetch(num_elements),
-        cycle_length=num_consumers,
-        num_parallel_calls=num_consumers)
+        lambda x: x,
+        block_length=replicas_per_consumer_host,
+        cycle_length=len(host_datasets),
+        num_parallel_calls=len(host_datasets),
+        deterministic=True)
 
     num_rounds = 10
     get_next = self.getNext(ds, requires_initialization=True)
     results = []
-    for _ in range(num_rounds):
+    for _ in range(num_rounds * num_consumers):
       results.append(self.evaluate(get_next()))
 
     def get_bucket(elem):
@@ -385,8 +404,10 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
         bucket_ind += 1
       return bucket_ind
 
+    # Check that the batches for each step contain elements from the same
+    # bucket.
     for i in range(0, len(results), num_consumers):
-      batches = results[num_consumers * i:num_consumers * i + num_consumers]
+      batches = results[num_consumers * i:num_consumers * (i + 1)]
       bucket_inds = [get_bucket(batch[0]) for batch in batches]
       for bucket_ind in bucket_inds[1:]:
         self.assertEqual(bucket_inds[0], bucket_ind)
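For reference, a minimal standalone sketch (illustrative only, not part of the commit) of the consumer layout the extended test exercises: each of `num_consumer_hosts` hosts drives `replicas_per_consumer_host` round-robin consumers, each host interleaves its replicas deterministically, and the outer interleave with `block_length=replicas_per_consumer_host` reads one batch from every consumer per global step, host by host.

    # Illustrative sketch only (not part of the diff): shows how the test's
    # consumer indices are laid out across hosts and the order in which one
    # global step reads them.
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host

    # Host h owns consumer indices
    # [h * replicas_per_consumer_host, (h + 1) * replicas_per_consumer_host).
    consumer_indices_by_host = [
        [host * replicas_per_consumer_host + i
         for i in range(replicas_per_consumer_host)]
        for host in range(num_consumer_hosts)
    ]

    # One global step visits every consumer exactly once, host by host:
    # [0, 1, 2, 3, 4, 5, 6, ..., 14]
    read_order_per_step = [
        index for host in consumer_indices_by_host for index in host
    ]
    assert len(read_order_per_step) == num_consumers
    assert read_order_per_step == list(range(num_consumers))

Under that layout, collecting `num_rounds * num_consumers` results corresponds to `num_rounds` global steps, which the verification loop above checks in groups of `num_consumers` batches.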