[tf.data service] Increase the number of client-side uncompress threads.
This is necessary to prevent uncompression from becoming the bottleneck. The change required updating the unit tests because now the `distribute` transformation may prefetch up to 16 elements. PiperOrigin-RevId: 316919714 Change-Id: I4e0c0b2985792a2a2a0f216de2143a645076b1c8
This commit is contained in:
parent
c1ae0ef7ce
commit
4747be646b
|
@ -240,8 +240,11 @@ def _distribute(processing_mode,
|
|||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||
# TODO(b/157105111): Make this an autotuned parallel map when we have a way
|
||||
# to limit memory usage.
|
||||
# The value 16 is chosen based on experience with pipelines that require
|
||||
# more than 8 parallel calls to prevent this stage from being a bottleneck.
|
||||
dataset = dataset.map(
|
||||
lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec))
|
||||
lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
|
||||
num_parallel_calls=16)
|
||||
|
||||
# Disable autosharding for shared jobs.
|
||||
if job_name:
|
||||
|
|
|
@ -201,13 +201,18 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
self._new_worker = server_lib.WorkerServer(
|
||||
port=port, master_address=self._master._address, protocol=PROTOCOL)
|
||||
|
||||
# The dataset starts over now that we read from the new worker.
|
||||
for i in range(num_elements):
|
||||
# There may have been some elements prefetched from the first worker
|
||||
# before it was stopped.
|
||||
while True:
|
||||
val = next(iterator).numpy()
|
||||
if val == 0:
|
||||
break
|
||||
|
||||
# The dataset starts over now that we read from the new worker.
|
||||
# TODO(b/157086991): Iterate until end of sequence when we support
|
||||
# detecting lost workers.
|
||||
for i in range(1, num_elements // 2):
|
||||
val = next(iterator).numpy()
|
||||
if val == midpoint and i != midpoint:
|
||||
# There may have been one last element prefetched from the first worker
|
||||
# before it was stopped.
|
||||
val = next(iterator).numpy()
|
||||
self.assertEqual(i, val)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
|
@ -248,7 +253,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testSharedJobName(self):
|
||||
num_elements = 10
|
||||
num_elements = 100
|
||||
master_address = self.create_cluster(1)
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
||||
|
@ -256,7 +261,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
iter1 = iter(ds1)
|
||||
iter2 = iter(ds2)
|
||||
results = []
|
||||
for _ in range(3):
|
||||
for _ in range(num_elements // 5):
|
||||
results.append(next(iter1).numpy())
|
||||
results.append(next(iter2).numpy())
|
||||
for elem in iter1:
|
||||
|
@ -291,7 +296,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testSharedJobNameRepeat(self):
|
||||
num_elements = 10
|
||||
num_elements = 100
|
||||
num_repetitions = 3
|
||||
master_address = self.create_cluster(1)
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
|
@ -302,9 +307,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
results = []
|
||||
iter1 = iter(ds1)
|
||||
iter2 = iter(ds2)
|
||||
for _ in range(((num_elements * num_repetitions) // 2) - 1):
|
||||
for _ in range((num_elements * num_repetitions) // 5):
|
||||
results.append(next(iter1).numpy())
|
||||
for _ in range(((num_elements * num_repetitions) // 2) - 1):
|
||||
for _ in range((num_elements * num_repetitions) // 5):
|
||||
results.append(next(iter2).numpy())
|
||||
for elem in iter1:
|
||||
results.append(elem.numpy())
|
||||
|
|
Loading…
Reference in New Issue