[tf.data service] Increase the number of client-side uncompress threads.

This is necessary to prevent uncompression from becoming the bottleneck.

The change required updating the unit tests, since the `distribute` transformation may now prefetch up to 16 elements.
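As background on why a parallel map implies prefetching: a `map` with `num_parallel_calls=N` may keep up to N elements in flight, so the consumer can observe elements produced ahead of its own requests. Below is a minimal, self-contained sketch of that effect, using a hypothetical `slow_transform` in place of the data service's uncompress step (an illustration under stated assumptions, not the commit's code):

import time
import tensorflow as tf

def slow_transform(x):
  # Stand-in for a CPU-heavy per-element step such as uncompression.
  return tf.py_function(
      func=lambda v: (time.sleep(0.01), v)[1], inp=[x], Tout=tf.int64)

ds = tf.data.Dataset.range(64)

# One call at a time: the transform gates the whole pipeline.
sequential = ds.map(slow_transform)

# Up to 16 calls in flight: elements are effectively prefetched.
parallel = ds.map(slow_transform, num_parallel_calls=16)

for name, d in [("sequential", sequential), ("parallel", parallel)]:
  start = time.time()
  for _ in d:
    pass
  print(name, "took", round(time.time() - start, 2), "seconds")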

PiperOrigin-RevId: 316919714
Change-Id: I4e0c0b2985792a2a2a0f216de2143a645076b1c8
Authored by Andrew Audibert on 2020-06-17 10:46:23 -07:00; committed by TensorFlower Gardener
parent c1ae0ef7ce
commit 4747be646b
2 changed files with 20 additions and 12 deletions


@@ -240,8 +240,11 @@ def _distribute(processing_mode,
       task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
   # TODO(b/157105111): Make this an autotuned parallel map when we have a way
   # to limit memory usage.
+  # The value 16 is chosen based on experience with pipelines that require
+  # more than 8 parallel calls to prevent this stage from being a bottleneck.
   dataset = dataset.map(
-      lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec))
+      lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
+      num_parallel_calls=16)
   # Disable autosharding for shared jobs.
   if job_name:
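The TODO kept in the hunk above points toward replacing the fixed value with autotuned parallelism once memory usage can be limited. A sketch of that eventual form, assuming `dataset_ops.AUTOTUNE` as the sentinel (my extrapolation, not part of this commit):

dataset = dataset.map(
    lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
    num_parallel_calls=dataset_ops.AUTOTUNE)

Until then, the fixed 16 trades a bounded amount of memory (up to 16 buffered elements) for enough parallelism to keep uncompression off the critical path.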


@@ -201,12 +201,17 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._new_worker = server_lib.WorkerServer(
         port=port, master_address=self._master._address, protocol=PROTOCOL)
-    # The dataset starts over now that we read from the new worker.
-    for i in range(num_elements):
-      val = next(iterator).numpy()
-      if val == midpoint and i != midpoint:
-        # There may have been one last element prefetched from the first worker
-        # before it was stopped.
-        val = next(iterator).numpy()
+    # There may have been some elements prefetched from the first worker
+    # before it was stopped.
+    while True:
+      val = next(iterator).numpy()
+      if val == 0:
+        break
+    # The dataset starts over now that we read from the new worker.
+    # TODO(b/157086991): Iterate until end of sequence when we support
+    # detecting lost workers.
+    for i in range(1, num_elements // 2):
+      val = next(iterator).numpy()
       self.assertEqual(i, val)
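The while-loop exists because the uncompress map now keeps up to 16 parallel calls in flight: when the first worker is stopped, elements it had already produced can still be delivered before the restarted sequence begins. Written out as a hypothetical helper (my naming, not the commit's):

def drain_until_restart(iterator):
  # Discard any elements prefetched from the stopped worker. The
  # restarted dataset is known to begin again at element 0.
  while True:
    if next(iterator).numpy() == 0:
      return

This is also why the verification loop starts at range(1, ...): the drain loop has already consumed element 0 of the restarted sequence.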
@@ -248,7 +253,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobName(self):
-    num_elements = 10
+    num_elements = 100
     master_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
     ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
@@ -256,7 +261,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     iter1 = iter(ds1)
     iter2 = iter(ds2)
     results = []
-    for _ in range(3):
+    for _ in range(num_elements // 5):
       results.append(next(iter1).numpy())
       results.append(next(iter2).numpy())
     for elem in iter1:
@@ -291,7 +296,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobNameRepeat(self):
-    num_elements = 10
+    num_elements = 100
     num_repetitions = 3
     master_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
@@ -302,9 +307,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     results = []
     iter1 = iter(ds1)
     iter2 = iter(ds2)
-    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+    for _ in range((num_elements * num_repetitions) // 5):
       results.append(next(iter1).numpy())
-    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+    for _ in range((num_elements * num_repetitions) // 5):
       results.append(next(iter2).numpy())
     for elem in iter1:
       results.append(elem.numpy())
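The shared-job constants grew for the same reason: each consumer may now buffer up to 16 elements. A back-of-the-envelope check (my arithmetic and the PREFETCH_BOUND name, not from the commit): with the old num_elements = 10, a single iterator's in-flight window could swallow the entire dataset before the second iterator read anything; with 100 elements and num_elements // 5 = 20 interleaved reads per iterator, both consumers make progress.

PREFETCH_BOUND = 16            # num_parallel_calls on the uncompress map
old_elements, new_elements = 10, 100

# The old dataset fit entirely inside one consumer's in-flight window.
assert old_elements <= PREFETCH_BOUND

# The new dataset exceeds two full windows, so neither iterator can
# starve the other during the interleaved reads.
assert new_elements > 2 * PREFETCH_BOUND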