[tf.data] Add benchmark to compare autotuned MapAndBatchDataset with autotuned ParallelMapDataset followed by ParallelBatchDataset.
PiperOrigin-RevId: 358303549
Change-Id: I49ad3093d72b78ad436d0d9a98f73dbaf4c2eaa8
commit 810e10d628
parent 72684b41eb
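For readers skimming the diff, the two pipelines being compared look roughly like the sketch below. This is a minimal illustration against the public tf.data API (tf.data.AUTOTUNE and Dataset.batch(num_parallel_calls=...); on older releases the constant lives at tf.data.experimental.AUTOTUNE), not an excerpt from the benchmark: the element function, dataset size, and batch size are placeholders. With fusion disabled, the autotuned map() and batch() stay as a ParallelMapDataset followed by a ParallelBatchDataset; with the map_and_batch_fusion rewrite enabled, the pair is rewritten into a single MapAndBatchDataset.

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # dataset_ops.AUTOTUNE in the diff below


def make_pipelines(batch_size=16):
  # Placeholder element function and dataset size; the benchmark instead maps
  # a large matmul over random matrices.
  base = tf.data.Dataset.range(1_000_000).map(lambda x: tf.cast(x, tf.float32))

  # Chained: autotuned ParallelMapDataset followed by autotuned
  # ParallelBatchDataset, with the fusion rewrite explicitly disabled.
  chained = base.map(lambda x: x * 2.0, num_parallel_calls=AUTOTUNE).batch(
      batch_size, num_parallel_calls=AUTOTUNE)
  chained_options = tf.data.Options()
  chained_options.experimental_optimization.map_and_batch_fusion = False
  chained = chained.with_options(chained_options)

  # Fused: autotuned map() plus a plain batch(); enabling map_and_batch_fusion
  # lets the optimizer rewrite the pair into a single MapAndBatchDataset.
  fused = base.map(lambda x: x * 2.0, num_parallel_calls=AUTOTUNE).batch(
      batch_size)
  fused_options = tf.data.Options()
  fused_options.experimental_optimization.map_and_batch_fusion = True
  fused = fused.with_options(fused_options)

  return chained, fused

The benchmark builds both variants from the same dataset definition by toggling apply_fusion and batch_num_calls, as the new fused_versus_chained_series configuration in the diff shows.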
@@ -23,8 +23,8 @@ import itertools
 import numpy as np

 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.benchmarks import benchmark_base
+from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import math_ops
@@ -60,6 +60,74 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
             warmup=True,
             name="num_elements_%d_batch_size_%d" % (np.prod(shape), batch_size))

+  def _benchmark_series(self, label, series):
+    """Runs benchmark the given series."""
+
+    # Decides a proper number of iterations according to the inputs.
+    def compute_num_iters(map_num_calls, inter_op, element_size, batch_size):
+      return 1024 // (
+          (element_size * batch_size) //
+          min(12 if map_num_calls == dataset_ops.AUTOTUNE else map_num_calls,
+              inter_op))
+
+    # Makes the dataset based on the inputs.
+    def make_dataset(map_num_calls, element_size, batch_size, batch_num_calls,
+                     apply_fusion):
+      k = 1024 * 1024
+      x = constant_op.constant(np.random.rand(element_size, 4 * k))
+      y = constant_op.constant(np.random.rand(4 * k, 1))
+      dataset = dataset_ops.Dataset.range(1000000000000).map(lambda _: (x, y))
+      dataset = dataset.map(math_ops.matmul, num_parallel_calls=map_num_calls)
+      dataset = dataset.batch(
+          batch_size=batch_size, num_parallel_calls=batch_num_calls)
+      options = dataset_ops.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_optimization.map_and_batch_fusion = apply_fusion
+      dataset = dataset.with_options(options)
+      return dataset
+
+    # Makes the name of the dataset based on the inputs.
+    def make_name(label, map_num_calls, inter_op, element_size, batch_size,
+                  batch_num_calls, apply_fusion):
+      map_num_calls_str = ("autotuned" if map_num_calls == dataset_ops.AUTOTUNE
+                           else str(map_num_calls))
+      batch_num_calls_str = (
+          "autotuned" if batch_num_calls == dataset_ops.AUTOTUNE else
+          str(1 if batch_num_calls is None else batch_num_calls))
+      name_str = ("%s_id_%s_map_num_calls_%s_batch_num_calls_%s_inter_op_%d"
+                  "_elem_size_%d_batch_size_%d")
+      name = (
+          name_str % (
+              "fused" if apply_fusion else "chained",
+              hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
+              map_num_calls_str,
+              batch_num_calls_str,
+              inter_op,
+              element_size,
+              batch_size,
+          ))
+      return name
+
+    for (map_num_calls, inter_op, element_size, batch_size, batch_num_calls,
+         apply_fusion) in series:
+      num_iters = compute_num_iters(map_num_calls, inter_op, element_size,
+                                    batch_size)
+      dataset = make_dataset(map_num_calls, element_size, batch_size,
+                             batch_num_calls, apply_fusion)
+      name = make_name(label, map_num_calls, inter_op, element_size, batch_size,
+                       batch_num_calls, apply_fusion)
+
+      session_config = config_pb2.ConfigProto(
+          inter_op_parallelism_threads=inter_op, use_per_session_threads=True)
+
+      self.run_and_report_benchmark(
+          dataset=dataset,
+          iters=num_iters,
+          num_elements=batch_size,
+          warmup=True,
+          session_config=session_config,
+          name=name)
+
   def benchmark_map_and_batch_chaining_versus_fusing(self):
     """Compares the performance of chaining and fusing map and batch.

@@ -69,79 +137,42 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
     """

     # Sequential pipeline configurations.
-    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
-    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16],
+                                             [None], [False, True])
+    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64],
+                                              [None], [False, True])

     # Parallel pipeline configuration.
-    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256],
+                                             [None], [False, True])
     par_batch_size_series = itertools.product([32], [32], [1],
-                                              [128, 256, 512, 1024])
-    par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
-    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+                                              [128, 256, 512, 1024], [None],
+                                              [False, True])
+    par_map_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1],
+                                                 [512], [None], [False, True])
+    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512],
+                                            [None], [False, True])

-    def name(method, label, num_calls, inter_op, element_size, batch_size):
-      return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
-          method,
-          hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
-          num_calls,
-          inter_op,
-          element_size,
-          batch_size,
-      ))
-
-    def benchmark(label, series):
-      """Runs benchmark the given series."""
-
-      def make_dataset(element_size, num_calls, batch_size):  # pylint: disable=missing-docstring
-        k = 1024 * 1024
-        x = constant_op.constant(np.random.rand(element_size, 4 * k))
-        y = constant_op.constant(np.random.rand(4 * k, 1))
-        dataset = dataset_ops.Dataset.range(1000000000000).map(lambda _: (x, y))
-        dataset = dataset.map(
-            math_ops.matmul,
-            num_parallel_calls=num_calls).batch(batch_size=batch_size)
-        options = dataset_ops.Options()
-        options.experimental_optimization.apply_default_optimizations = False
-        return dataset.with_options(options)
-
-      for num_calls, inter_op, element_size, batch_size in series:
-        num_iters = 1024 // (
-            (element_size * batch_size) // min(num_calls, inter_op))
-        # By default the chained map().batch() calls will not be fused.
-        chained_dataset = make_dataset(element_size, num_calls, batch_size)
-        session_config = config_pb2.ConfigProto(
-            inter_op_parallelism_threads=inter_op, use_per_session_threads=True)
-
-        self.run_and_report_benchmark(
-            dataset=chained_dataset,
-            iters=num_iters,
-            num_elements=batch_size,
-            warmup=True,
-            session_config=session_config,
-            name=name("chained", label, num_calls, inter_op, element_size,
-                      batch_size))
-
-        # Apply an option to the default dataset that will fuse map().batch().
-        options = dataset_ops.Options()
-        options.experimental_optimization.map_and_batch_fusion = True
-        fused_dataset = chained_dataset.with_options(options)
-
-        self.run_and_report_benchmark(
-            dataset=fused_dataset,
-            iters=num_iters,
-            num_elements=batch_size,
-            warmup=True,
-            session_config=session_config,
-            name=name("fused", label, num_calls, inter_op, element_size,
-                      batch_size))
+    # Autotuned pipeline configuration.
+    fused_versus_chained_series = [
+        (dataset_ops.AUTOTUNE, 32, 1, 16, dataset_ops.AUTOTUNE, False),
+        (dataset_ops.AUTOTUNE, 32, 1, 16, None, True)
+    ]

     np.random.seed(_NUMPY_RANDOM_SEED)
-    benchmark("Sequential element size evaluation", seq_elem_size_series)
-    benchmark("Sequential batch size evaluation", seq_batch_size_series)
-    benchmark("Parallel element size evaluation", par_elem_size_series)
-    benchmark("Parallel batch size evaluation", par_batch_size_series)
-    benchmark("Transformation parallelism evaluation", par_num_calls_series)
-    benchmark("Threadpool size evaluation", par_inter_op_series)
+    self._benchmark_series("Sequential element size evaluation",
+                           seq_elem_size_series)
+    self._benchmark_series("Sequential batch size evaluation",
+                           seq_batch_size_series)
+    self._benchmark_series("Parallel element size evaluation",
+                           par_elem_size_series)
+    self._benchmark_series("Parallel batch size evaluation",
+                           par_batch_size_series)
+    self._benchmark_series("Transformation parallelism evaluation",
+                           par_map_num_calls_series)
+    self._benchmark_series("Threadpool size evaluation", par_inter_op_series)
+    self._benchmark_series("Autotune chained versus fused evaluation",
+                           fused_versus_chained_series)


 if __name__ == "__main__":