[tf.data service] Autotune uncompression and server-side prefetching.

PiperOrigin-RevId: 334618766
Change-Id: I3a9a7e22f3494738084b94f23a4b275ce3a53ffc
This commit is contained in:
Andrew Audibert 2020-09-30 09:46:30 -07:00 committed by TensorFlower Gardener
parent 54da2d9fd0
commit fb8382b463

View File

@@ -238,13 +238,9 @@ def _from_dataset_id(processing_mode,
job_name=job_name,
max_outstanding_requests=max_outstanding_requests,
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
-# TODO(b/157105111): Make this an autotuned parallel map when we have a way
-# to limit memory usage.
-# The value 16 is chosen based on experience with pipelines that require
-# more than 8 parallel calls to prevent this stage from being a bottleneck.
dataset = dataset.map(
lambda x: compression_ops.uncompress(x, output_spec=element_spec),
-num_parallel_calls=16)
+num_parallel_calls=dataset_ops.AUTOTUNE)
# Disable autosharding for shared jobs.
if job_name:
@@ -537,11 +533,7 @@ def register_dataset(service, dataset):
dataset = dataset.map(
lambda *x: compression_ops.compress(x),
num_parallel_calls=dataset_ops.AUTOTUNE)
-# Prefetch one compressed element to reduce latency when requesting data
-# from tf.data workers.
-# TODO(b/157105111): Set this to autotune when we have a way to limit
-# memory usage
-dataset = dataset.prefetch(1)
+dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
# Apply options so that the dataset executed in the tf.data service will
# be optimized and support autotuning.
dataset = dataset._apply_options() # pylint: disable=protected-access