[tf.data service] Extend round robin unit test to cover multiple replicas per host.

PiperOrigin-RevId: 351900031
Change-Id: I2d1eb18cc24da64581f324f66ff244a277cc37ce
Andrew Audibert 2021-01-14 16:22:38 -08:00 committed by TensorFlower Gardener
parent 08d6a26260
commit 9aedc576ae


@@ -334,14 +334,18 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     # Round robin reads can cause slow cluster shutdown.
     GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
-    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
-    ds = ds.shuffle(num_elements)
     low_bucket_max = 30
     mid_bucket_max = 60
     bucket_boundaries = [low_bucket_max, mid_bucket_max]
     batch_size = 10
-    num_consumers = 3
+    num_consumer_hosts = 3
+    replicas_per_consumer_host = 5
+    num_consumers = num_consumer_hosts * replicas_per_consumer_host
     bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
+    # Set up the dataset that will run on the tf.data workers.
+    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
+    ds = ds.shuffle(num_elements)
+    ds = ds.repeat()
     ds = ds.apply(
         grouping.bucket_by_sequence_length(
             lambda x: x,
@@ -354,28 +358,43 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
             lambda _, x: dataset_ops.Dataset.from_tensors(x),
             window_size=num_consumers))
     ds = ds.flat_map(lambda x: x)
-    ds = ds.repeat()
-    consumers = []
-    for consumer_index in range(num_consumers):
-      consumers.append(
-          self.make_distributed_dataset(
-              ds,
-              cluster,
-              job_name="test",
-              consumer_index=consumer_index,
-              num_consumers=num_consumers))
-    # Use parallel interleave to read from consumers in parallel.
-    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
+    # Set up the per-consumer-host datasets. During each global step, we pull
+    # `replicas_per_consumer_host` batches from each of these datasets.
+    host_datasets = []
+    for host_index in range(num_consumer_hosts):
+      per_replica_datasets = []
+      for i in range(replicas_per_consumer_host):
+        consumer_index = host_index * replicas_per_consumer_host + i
+        per_replica_datasets.append(
+            self.make_distributed_dataset(
+                ds,
+                cluster,
+                job_name="test",
+                consumer_index=consumer_index,
+                num_consumers=num_consumers))
+      host_dataset = dataset_ops.Dataset.from_tensor_slices(
+          per_replica_datasets)
+      host_dataset = host_dataset.interleave(
+          lambda x: x,
+          cycle_length=len(per_replica_datasets),
+          num_parallel_calls=len(per_replica_datasets),
+          deterministic=True)
+      host_datasets.append(host_dataset)
+    # Use parallel interleave to read from host datasets in parallel.
+    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
     ds = ds.interleave(
-        lambda x: x.prefetch(num_elements),
-        cycle_length=num_consumers,
-        num_parallel_calls=num_consumers)
+        lambda x: x,
+        block_length=replicas_per_consumer_host,
+        cycle_length=len(host_datasets),
+        num_parallel_calls=len(host_datasets),
+        deterministic=True)
     num_rounds = 10
     get_next = self.getNext(ds, requires_initialization=True)
     results = []
-    for _ in range(num_rounds):
+    for _ in range(num_rounds * num_consumers):
       results.append(self.evaluate(get_next()))
     def get_bucket(elem):
@@ -385,8 +404,10 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
         bucket_ind += 1
       return bucket_ind
+    # Check that the batches for each step contain elements from the same
+    # bucket.
     for i in range(0, len(results), num_consumers):
-      batches = results[num_consumers * i:num_consumers * i + num_consumers]
+      batches = results[num_consumers * i:num_consumers * (i + 1)]
       bucket_inds = [get_bucket(batch[0]) for batch in batches]
       for bucket_ind in bucket_inds[1:]:
         self.assertEqual(bucket_inds[0], bucket_ind)
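
For context on the read pattern the test now exercises, here is a minimal sketch that is not part of the commit: it swaps the `make_distributed_dataset` consumers for hypothetical in-memory stand-ins (each replica just repeats its global consumer index) to show how the same nested interleave, with `block_length=replicas_per_consumer_host` and `deterministic=True`, reads all replicas of host 0, then host 1, then host 2 within each global step.

import tensorflow as tf

num_consumer_hosts = 3
replicas_per_consumer_host = 5

# Stand-ins for the per-replica distributed datasets: the replica with global
# consumer index c simply repeats the value c forever.
host_datasets = []
for host_index in range(num_consumer_hosts):
  per_replica_datasets = [
      tf.data.Dataset.from_tensors(
          host_index * replicas_per_consumer_host + i).repeat()
      for i in range(replicas_per_consumer_host)
  ]
  host_dataset = tf.data.Dataset.from_tensor_slices(per_replica_datasets)
  # Deterministically round-robin across the replicas on this host.
  host_dataset = host_dataset.interleave(
      lambda x: x,
      cycle_length=len(per_replica_datasets),
      num_parallel_calls=len(per_replica_datasets),
      deterministic=True)
  host_datasets.append(host_dataset)

ds = tf.data.Dataset.from_tensor_slices(host_datasets)
# block_length pulls all of a host's replicas before moving to the next host.
ds = ds.interleave(
    lambda x: x,
    block_length=replicas_per_consumer_host,
    cycle_length=len(host_datasets),
    num_parallel_calls=len(host_datasets),
    deterministic=True)

num_consumers = num_consumer_hosts * replicas_per_consumer_host
# One global step reads one element per consumer: [0, 1, ..., 14].
print(list(ds.take(num_consumers).as_numpy_iterator()))

With 3 hosts and 5 replicas per host, the first 15 elements come out as 0 through 14, one per consumer, which mirrors the grouping the test applies when it slices `results` into runs of `num_consumers` batches.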