diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py
index 39790d843ba..01ec155a89c 100644
--- a/tensorflow/python/data/experimental/ops/data_service_ops.py
+++ b/tensorflow/python/data/experimental/ops/data_service_ops.py
@@ -240,8 +240,11 @@ def _distribute(processing_mode,
       task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
   # TODO(b/157105111): Make this an autotuned parallel map when we have a way
   # to limit memory usage.
+  # The value 16 is chosen based on experience with pipelines that require
+  # more than 8 parallel calls to prevent this stage from being a bottleneck.
   dataset = dataset.map(
-      lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec))
+      lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
+      num_parallel_calls=16)
 
   # Disable autosharding for shared jobs.
   if job_name:
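Note on the `num_parallel_calls=16` change above: the sketch below shows the same fixed-degree parallel map pattern on a toy pipeline. It is illustrative only; `tf.strings.to_number` stands in for `compression_ops.uncompress`, which is internal to the tf.data service, and the dataset contents are made up.

```python
import tensorflow as tf

# Toy stand-in pipeline: in the real code the elements are compressed
# tensors returned by tf.data service workers, and the map applies
# compression_ops.uncompress. tf.strings.to_number is used here purely
# as a placeholder for an expensive per-element transformation.
dataset = tf.data.Dataset.from_tensor_slices(
    [str(i) for i in range(1000)])

# Serial map: elements are transformed one at a time, so this stage can
# become the pipeline bottleneck.
serial = dataset.map(lambda x: tf.strings.to_number(x))

# Fixed-degree parallel map, mirroring the change above. A constant is
# used instead of tf.data.experimental.AUTOTUNE because autotuning this
# stage currently has no way to limit memory usage (see the TODO in the
# diff).
parallel = dataset.map(
    lambda x: tf.strings.to_number(x),
    num_parallel_calls=16)
```

The fixed degree is a stopgap: once autotuning can bound the memory used by in-flight parallel calls, the TODO above suggests the constant should be replaced by autotuning.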
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index d316009ce0c..2356a866d6e 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -201,13 +201,18 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self._new_worker = server_lib.WorkerServer(
         port=port, master_address=self._master._address, protocol=PROTOCOL)
 
-    # The dataset starts over now that we read from the new worker.
-    for i in range(num_elements):
+    # There may have been some elements prefetched from the first worker
+    # before it was stopped.
+    while True:
+      val = next(iterator).numpy()
+      if val == 0:
+        break
+
+    # The dataset starts over now that we read from the new worker.
+    # TODO(b/157086991): Iterate until end of sequence when we support
+    # detecting lost workers.
+    for i in range(1, num_elements // 2):
       val = next(iterator).numpy()
-      if val == midpoint and i != midpoint:
-        # There may have been one last element prefetched from the first worker
-        # before it was stopped.
-        val = next(iterator).numpy()
       self.assertEqual(i, val)
 
   @combinations.generate(test_base.eager_only_combinations())
@@ -248,7 +253,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobName(self):
-    num_elements = 10
+    num_elements = 100
     master_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
     ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
@@ -256,7 +261,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     iter1 = iter(ds1)
     iter2 = iter(ds2)
     results = []
-    for _ in range(3):
+    for _ in range(num_elements // 5):
       results.append(next(iter1).numpy())
       results.append(next(iter2).numpy())
     for elem in iter1:
@@ -291,7 +296,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobNameRepeat(self):
-    num_elements = 10
+    num_elements = 100
     num_repetitions = 3
     master_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
@@ -302,9 +307,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     results = []
     iter1 = iter(ds1)
     iter2 = iter(ds2)
-    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+    for _ in range((num_elements * num_repetitions) // 5):
       results.append(next(iter1).numpy())
-    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+    for _ in range((num_elements * num_repetitions) // 5):
       results.append(next(iter2).numpy())
     for elem in iter1:
       results.append(elem.numpy())
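For reviewers, here is the consumption pattern the reworked restart test relies on, extracted as a standalone sketch. The helper name `read_after_restart` and the plain-int iterator are illustrative assumptions; the real test calls `.numpy()` on each element and asserts with `self.assertEqual`.

```python
def read_after_restart(iterator, num_elements):
  # Any number of elements may have been prefetched from the first
  # worker before it was stopped, so we cannot know in advance how many
  # stale values remain. Skip forward until the sequence visibly starts
  # over at 0, which marks output from the new worker.
  while True:
    if next(iterator) == 0:
      break
  # The new worker replays the dataset from the beginning. Only the
  # first half is checked because detecting end-of-sequence for lost
  # workers is not supported yet (b/157086991).
  for i in range(1, num_elements // 2):
    assert next(iterator) == i

# Example: three leftover prefetched values (7, 8, 9), then the restart.
read_after_restart(iter([7, 8, 9, 0, 1, 2, 3, 4]), num_elements=10)
```

Draining until the restart marker makes the test robust to however many elements the stopped worker had prefetched, which is why the old `midpoint`-based special case could be deleted.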