Parallelizing dataset creation for `experimental_distribute_dataset_from_function`.

PiperOrigin-RevId: 328983324
Change-Id: Ifc233024e24bad07429578715e8e1963cea09d31
This commit is contained in:
Chenkai Kuang 2020-08-28 12:03:17 -07:00 committed by TensorFlower Gardener
parent 9c3e4058f9
commit c49645f4b7

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import concurrent
import functools
import sys
@ -57,6 +58,9 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
# Maximum number of Python threads used to create per-worker datasets from a
# dataset function in parallel (passed as `max_workers` to the thread pool).
_DATASET_FROM_FUNCTION_MAX_PARALLELISM = 100
def get_distributed_dataset(dataset,
input_workers,
@ -1741,9 +1745,39 @@ def _create_iterators_per_worker(worker_datasets, input_workers,
return iterators
def _create_datasets_per_worker_with_input_context_in_parallel(
    input_contexts, input_workers, dataset_fn):
  """Creates one dataset per worker from `dataset_fn`, using a thread pool.

  Each worker's dataset is built by calling `dataset_fn` with that worker's
  input context while the worker's device is the active device scope. One
  task per input context is submitted to a thread pool capped at
  `_DATASET_FROM_FUNCTION_MAX_PARALLELISM` threads.

  Args:
    input_contexts: Sequence of input contexts, one per worker. Must be
      non-empty, because the element spec is read from the last dataset.
    input_workers: Object exposing `worker_devices`, indexed in the same order
      as `input_contexts`.
    dataset_fn: Callable that takes a single input context and returns a
      dataset exposing `element_spec`.

  Returns:
    A `(datasets, element_spec)` tuple where `datasets[i]` is the dataset built
    for `input_contexts[i]` and `element_spec` is taken from the last dataset.

  Raises:
    Whatever `dataset_fn` raises: an exception from any worker task is
    re-raised in the calling thread.
  """
  # `import concurrent` alone does not load the `futures` submodule; import it
  # explicitly so this function does not depend on another module having
  # already done so.
  import concurrent.futures  # pylint: disable=g-import-not-at-top

  # Pre-sized so each task can store its result by index, keeping the output
  # order aligned with `input_contexts` regardless of completion order.
  datasets = [None] * len(input_contexts)

  def create_dataset(worker_idx):
    worker = input_workers.worker_devices[worker_idx]
    with ops.device(worker):
      dataset = dataset_fn(input_contexts[worker_idx])
      datasets[worker_idx] = dataset

  # In the default case (local sync executor + streaming enqueue enabled),
  # creation of datasets is blocking. Since dataset creation is thread-safe,
  # we can parallelize it using multiple Python threads.
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=_DATASET_FROM_FUNCTION_MAX_PARALLELISM) as executor:
    futures = [
        executor.submit(create_dataset, i) for i in range(len(input_contexts))
    ]
    for future in concurrent.futures.as_completed(futures):
      future.result()  # Re-raises if `create_dataset` raised.

  return datasets, datasets[-1].element_spec
def _create_datasets_per_worker_with_input_context(input_contexts,
input_workers, dataset_fn):
"""Create device datasets per worker given a dataset function."""
if context.executing_eagerly() and len(input_contexts) > 1:
return _create_datasets_per_worker_with_input_context_in_parallel(
input_contexts, input_workers, dataset_fn)
# For graph mode, creating dataset is just adding nodes to the graph, there is
# no need to parallelize. Plus graph construction is not thread safe.
datasets = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]