Parallelizing dataset creation for `experimental_distribute_dataset_from_function`.
PiperOrigin-RevId: 328983324
Change-Id: Ifc233024e24bad07429578715e8e1963cea09d31
Parent: 9c3e4058f9
Commit: c49645f4b7
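For context, here is a hedged sketch (not part of this commit) of the user-facing pattern whose internals this change speeds up. It assumes the TF 2.x tf.distribute API of that era; the public method was exposed as `experimental_distribute_datasets_from_function` (later renamed `distribute_datasets_from_function`), which differs slightly from the internal name in the title. All other names below are illustrative.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
  # Each input pipeline builds its own shard; `input_context` says which one.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(1000).shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  return dataset.batch(batch_size)

# Before this commit, `dataset_fn` ran serially per worker; with this change
# it is invoked from a thread pool when executing eagerly with multiple
# workers.
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)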
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import concurrent.futures
 import functools
 import sys
 
@@ -57,6 +58,9 @@ from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls
 
+# Maximum number of threads for creating datasets from function.
+_DATASET_FROM_FUNCTION_MAX_PARALLELISM = 100
+
 
 def get_distributed_dataset(dataset,
                             input_workers,
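The new constant caps only the number of threads, not the number of datasets that can be created. A standalone sketch (nothing here is from the commit) showing that a ThreadPoolExecutor with a max_workers cap simply queues excess tasks until a thread frees up:

import concurrent.futures

MAX_PARALLELISM = 100  # Mirrors _DATASET_FROM_FUNCTION_MAX_PARALLELISM.

def make_fake_dataset(idx):
  # Stand-in for a per-worker dataset_fn call.
  return ('dataset', idx)

with concurrent.futures.ThreadPoolExecutor(
    max_workers=MAX_PARALLELISM) as ex:
  # 250 tasks against 100 threads: the extra 150 just wait in the queue.
  results = list(ex.map(make_fake_dataset, range(250)))

assert len(results) == 250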
@@ -1741,9 +1745,39 @@ def _create_iterators_per_worker(worker_datasets, input_workers,
   return iterators
 
 
+def _create_datasets_per_worker_with_input_context_in_parallel(
+    input_contexts, input_workers, dataset_fn):
+  """Create device datasets per worker given a dataset function in parallel."""
+  datasets = [None] * len(input_contexts)
+
+  def create_dataset(worker_idx):
+    worker = input_workers.worker_devices[worker_idx]
+    with ops.device(worker):
+      dataset = dataset_fn(input_contexts[worker_idx])
+      datasets[worker_idx] = dataset
+
+  # In the default case (local sync executor + streaming enqueue enabled),
+  # dataset creation is blocking. Since dataset creation is thread-safe, we
+  # can parallelize it using multiple Python threads.
+  with concurrent.futures.ThreadPoolExecutor(
+      max_workers=_DATASET_FROM_FUNCTION_MAX_PARALLELISM) as executor:
+    futures = [
+        executor.submit(create_dataset, i) for i in range(len(input_contexts))
+    ]
+    for future in concurrent.futures.as_completed(futures):
+      future.result()  # Re-raise any exception from `create_dataset`.
+
+  return datasets, datasets[-1].element_spec
+
+
 def _create_datasets_per_worker_with_input_context(input_contexts,
                                                    input_workers, dataset_fn):
   """Create device datasets per worker given a dataset function."""
+  if context.executing_eagerly() and len(input_contexts) > 1:
+    return _create_datasets_per_worker_with_input_context_in_parallel(
+        input_contexts, input_workers, dataset_fn)
+  # In graph mode, creating a dataset just adds nodes to the graph, so there
+  # is no need to parallelize. Graph construction is also not thread-safe.
   datasets = []
   for i, ctx in enumerate(input_contexts):
     worker = input_workers.worker_devices[i]
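The `future.result()` call in the hunk above is what propagates failures out of the pool into the calling thread. A self-contained sketch of that submit/as_completed/result pattern, with hypothetical names (`create_resource`, `worker`) standing in for the dataset logic:

import concurrent.futures

def create_resource(idx):
  # Stand-in for dataset_fn; fails for one worker to show propagation.
  if idx == 3:
    raise ValueError('creation failed for worker %d' % idx)
  return idx * idx

results = [None] * 5

def worker(idx):
  # Each thread writes to its own slot, like `datasets[worker_idx]` above.
  results[idx] = create_resource(idx)

with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
  futures = [executor.submit(worker, i) for i in range(5)]
  try:
    for future in concurrent.futures.as_completed(futures):
      future.result()  # Re-raises the ValueError from worker 3 here.
  except ValueError as e:
    print('propagated:', e)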