[tf.data service] Increase the number of client-side uncompress threads.
This is necessary to prevent uncompression from becoming the bottleneck. The change required updating the unit tests because now the `distribute` transformation may prefetch up to 16 elements. PiperOrigin-RevId: 316919714 Change-Id: I4e0c0b2985792a2a2a0f216de2143a645076b1c8
This commit is contained in:
parent
c1ae0ef7ce
commit
4747be646b
|
@ -240,8 +240,11 @@ def _distribute(processing_mode,
|
||||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||||
# TODO(b/157105111): Make this an autotuned parallel map when we have a way
|
# TODO(b/157105111): Make this an autotuned parallel map when we have a way
|
||||||
# to limit memory usage.
|
# to limit memory usage.
|
||||||
|
# The value 16 is chosen based on experience with pipelines that require
|
||||||
|
# more than 8 parallel calls to prevent this stage from being a bottleneck.
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec))
|
lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
|
||||||
|
num_parallel_calls=16)
|
||||||
|
|
||||||
# Disable autosharding for shared jobs.
|
# Disable autosharding for shared jobs.
|
||||||
if job_name:
|
if job_name:
|
||||||
|
|
|
@ -201,12 +201,17 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
self._new_worker = server_lib.WorkerServer(
|
self._new_worker = server_lib.WorkerServer(
|
||||||
port=port, master_address=self._master._address, protocol=PROTOCOL)
|
port=port, master_address=self._master._address, protocol=PROTOCOL)
|
||||||
|
|
||||||
# The dataset starts over now that we read from the new worker.
|
# There may have been some elements prefetched from the first worker
|
||||||
for i in range(num_elements):
|
|
||||||
val = next(iterator).numpy()
|
|
||||||
if val == midpoint and i != midpoint:
|
|
||||||
# There may have been one last element prefetched from the first worker
|
|
||||||
# before it was stopped.
|
# before it was stopped.
|
||||||
|
while True:
|
||||||
|
val = next(iterator).numpy()
|
||||||
|
if val == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# The dataset starts over now that we read from the new worker.
|
||||||
|
# TODO(b/157086991): Iterate until end of sequence when we support
|
||||||
|
# detecting lost workers.
|
||||||
|
for i in range(1, num_elements // 2):
|
||||||
val = next(iterator).numpy()
|
val = next(iterator).numpy()
|
||||||
self.assertEqual(i, val)
|
self.assertEqual(i, val)
|
||||||
|
|
||||||
|
@ -248,7 +253,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testSharedJobName(self):
|
def testSharedJobName(self):
|
||||||
num_elements = 10
|
num_elements = 100
|
||||||
master_address = self.create_cluster(1)
|
master_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
||||||
|
@ -256,7 +261,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
iter2 = iter(ds2)
|
iter2 = iter(ds2)
|
||||||
results = []
|
results = []
|
||||||
for _ in range(3):
|
for _ in range(num_elements // 5):
|
||||||
results.append(next(iter1).numpy())
|
results.append(next(iter1).numpy())
|
||||||
results.append(next(iter2).numpy())
|
results.append(next(iter2).numpy())
|
||||||
for elem in iter1:
|
for elem in iter1:
|
||||||
|
@ -291,7 +296,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testSharedJobNameRepeat(self):
|
def testSharedJobNameRepeat(self):
|
||||||
num_elements = 10
|
num_elements = 100
|
||||||
num_repetitions = 3
|
num_repetitions = 3
|
||||||
master_address = self.create_cluster(1)
|
master_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
|
@ -302,9 +307,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
results = []
|
results = []
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
iter2 = iter(ds2)
|
iter2 = iter(ds2)
|
||||||
for _ in range(((num_elements * num_repetitions) // 2) - 1):
|
for _ in range((num_elements * num_repetitions) // 5):
|
||||||
results.append(next(iter1).numpy())
|
results.append(next(iter1).numpy())
|
||||||
for _ in range(((num_elements * num_repetitions) // 2) - 1):
|
for _ in range((num_elements * num_repetitions) // 5):
|
||||||
results.append(next(iter2).numpy())
|
results.append(next(iter2).numpy())
|
||||||
for elem in iter1:
|
for elem in iter1:
|
||||||
results.append(elem.numpy())
|
results.append(elem.numpy())
|
||||||
|
|
Loading…
Reference in New Issue