[tf.data service] Autotune uncompression and server-side prefetching.
PiperOrigin-RevId: 334618766
Change-Id: I3a9a7e22f3494738084b94f23a4b275ce3a53ffc
commit fb8382b463
parent 54da2d9fd0
@@ -238,13 +238,9 @@ def _from_dataset_id(processing_mode,
       job_name=job_name,
       max_outstanding_requests=max_outstanding_requests,
       task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
-  # TODO(b/157105111): Make this an autotuned parallel map when we have a way
-  # to limit memory usage.
-  # The value 16 is chosen based on experience with pipelines that require
-  # more than 8 parallel calls to prevent this stage from being a bottleneck.
   dataset = dataset.map(
       lambda x: compression_ops.uncompress(x, output_spec=element_spec),
-      num_parallel_calls=16)
+      num_parallel_calls=dataset_ops.AUTOTUNE)
 
   # Disable autosharding for shared jobs.
   if job_name:
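For readers outside the TensorFlow tree, the hunk above is roughly the public-API pattern sketched below: a parallel map whose parallelism is chosen by tf.data's autotuner instead of a hard-coded 16. The serialize/parse pair here stands in for the internal compression_ops.uncompress, which is not a public API, so treat this as an illustration of an autotuned map, not the actual service code.

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

def _uncompress(serialized):
  # Stand-in for compression_ops.uncompress: decode a serialized tensor back
  # to its original dtype (the real op also restores the full element_spec).
  return tf.io.parse_tensor(serialized, out_type=tf.int64)

ds = tf.data.Dataset.range(10)
ds = ds.map(tf.io.serialize_tensor)                    # stand-in for the compressed wire format
ds = ds.map(_uncompress, num_parallel_calls=AUTOTUNE)  # autotuned, no fixed num_parallel_calls=16
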
@@ -537,11 +533,7 @@ def register_dataset(service, dataset):
   dataset = dataset.map(
       lambda *x: compression_ops.compress(x),
       num_parallel_calls=dataset_ops.AUTOTUNE)
-  # Prefetch one compressed element to reduce latency when requesting data
-  # from tf.data workers.
-  # TODO(b/157105111): Set this to autotune when we have a way to limit
-  # memory usage
-  dataset = dataset.prefetch(1)
+  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
   # Apply options so that the dataset executed in the tf.data service will
   # be optimized and support autotuning.
   dataset = dataset._apply_options()  # pylint: disable=protected-access
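The second hunk likewise swaps a fixed prefetch(1) for an autotuned prefetch on the dataset that is registered with the tf.data service. In user code the same constant is exposed as tf.data.experimental.AUTOTUNE; a minimal runnable sketch of the pattern on an ordinary pipeline (not the internal compress/apply-options path) might look like:

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

ds = tf.data.Dataset.range(1000)
ds = ds.map(lambda x: x * 2, num_parallel_calls=AUTOTUNE)
ds = ds.prefetch(AUTOTUNE)  # buffer size picked by the autotuner rather than fixed at 1

for element in ds.take(3):
  print(element.numpy())  # 0, 2, 4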