[tf.data] Add benchmark to compare autotuned MapAndBatchDataset with autotuned ParallelMapDataset followed by ParallelBatchDataset.

PiperOrigin-RevId: 358303549
Change-Id: I49ad3093d72b78ad436d0d9a98f73dbaf4c2eaa8
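
For context, the two pipeline shapes this change benchmarks can be sketched with the public tf.data API as follows. This is a minimal illustration rather than part of the commit; it assumes TensorFlow 2.4 or later (where Dataset.batch gained num_parallel_calls), and the helper names chained and fused are ours:

import tensorflow as tf

def chained(dataset, batch_size):
  # Autotuned ParallelMapDataset followed by an autotuned ParallelBatchDataset.
  dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
  return dataset.batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)

def fused(dataset, batch_size):
  # The same map().batch() chain, but with map_and_batch_fusion enabled so the
  # optimizer rewrites it into a single autotuned MapAndBatchDataset.
  dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size)
  options = tf.data.Options()
  options.experimental_optimization.map_and_batch_fusion = True
  return dataset.with_options(options)

source = tf.data.Dataset.range(1024)
for make_pipeline in (chained, fused):
  for _ in make_pipeline(source, batch_size=16).take(4):
    pass  # pull a few batches through each variant

These two shapes correspond to the (AUTOTUNE, ..., AUTOTUNE, False) and (AUTOTUNE, ..., None, True) rows of the fused_versus_chained_series added in the diff below.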
Author: Jay Shi
Date: 2021-02-18 17:40:30 -08:00 (committed by TensorFlower Gardener)
parent 72684b41eb
commit 810e10d628


@@ -23,8 +23,8 @@ import itertools
 import numpy as np
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.benchmarks import benchmark_base
+from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import math_ops
@@ -60,6 +60,74 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
         warmup=True,
         name="num_elements_%d_batch_size_%d" % (np.prod(shape), batch_size))
 
+  def _benchmark_series(self, label, series):
+    """Runs the benchmark for each configuration in the given series."""
+
+    # Decides an appropriate number of iterations according to the inputs.
+    def compute_num_iters(map_num_calls, inter_op, element_size, batch_size):
+      return 1024 // (
+          (element_size * batch_size) //
+          min(12 if map_num_calls == dataset_ops.AUTOTUNE else map_num_calls,
+              inter_op))
+
+    # Makes the dataset based on the inputs.
+    def make_dataset(map_num_calls, element_size, batch_size, batch_num_calls,
+                     apply_fusion):
+      k = 1024 * 1024
+      x = constant_op.constant(np.random.rand(element_size, 4 * k))
+      y = constant_op.constant(np.random.rand(4 * k, 1))
+      dataset = dataset_ops.Dataset.range(1000000000000).map(lambda _: (x, y))
+      dataset = dataset.map(math_ops.matmul, num_parallel_calls=map_num_calls)
+      dataset = dataset.batch(
+          batch_size=batch_size, num_parallel_calls=batch_num_calls)
+      options = dataset_ops.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_optimization.map_and_batch_fusion = apply_fusion
+      dataset = dataset.with_options(options)
+      return dataset
+
+    # Makes the name of the benchmark based on the inputs.
+    def make_name(label, map_num_calls, inter_op, element_size, batch_size,
+                  batch_num_calls, apply_fusion):
+      map_num_calls_str = ("autotuned" if map_num_calls == dataset_ops.AUTOTUNE
+                           else str(map_num_calls))
+      batch_num_calls_str = (
+          "autotuned" if batch_num_calls == dataset_ops.AUTOTUNE else
+          str(1 if batch_num_calls is None else batch_num_calls))
+      name_str = ("%s_id_%s_map_num_calls_%s_batch_num_calls_%s_inter_op_%d"
+                  "_elem_size_%d_batch_size_%d")
+      name = (
+          name_str % (
+              "fused" if apply_fusion else "chained",
+              hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
+              map_num_calls_str,
+              batch_num_calls_str,
+              inter_op,
+              element_size,
+              batch_size,
+          ))
+      return name
+
+    for (map_num_calls, inter_op, element_size, batch_size, batch_num_calls,
+         apply_fusion) in series:
+      num_iters = compute_num_iters(map_num_calls, inter_op, element_size,
+                                    batch_size)
+      dataset = make_dataset(map_num_calls, element_size, batch_size,
+                             batch_num_calls, apply_fusion)
+      name = make_name(label, map_num_calls, inter_op, element_size,
+                       batch_size, batch_num_calls, apply_fusion)
+      session_config = config_pb2.ConfigProto(
+          inter_op_parallelism_threads=inter_op, use_per_session_threads=True)
+      self.run_and_report_benchmark(
+          dataset=dataset,
+          iters=num_iters,
+          num_elements=batch_size,
+          warmup=True,
+          session_config=session_config,
+          name=name)
+
   def benchmark_map_and_batch_chaining_versus_fusing(self):
     """Compares the performance of chaining and fusing map and batch.
@@ -69,79 +137,42 @@ class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
     """
 
     # Sequential pipeline configurations.
-    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
-    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16],
+                                             [None], [False, True])
+    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64],
+                                              [None], [False, True])
 
     # Parallel pipeline configuration.
-    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256],
+                                             [None], [False, True])
     par_batch_size_series = itertools.product([32], [32], [1],
-                                              [128, 256, 512, 1024])
-    par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
-    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+                                              [128, 256, 512, 1024], [None],
+                                              [False, True])
+    par_map_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1],
+                                                 [512], [None], [False, True])
+    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512],
+                                            [None], [False, True])
 
-    def name(method, label, num_calls, inter_op, element_size, batch_size):
-      return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
-          method,
-          hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
-          num_calls,
-          inter_op,
-          element_size,
-          batch_size,
-      ))
-
-    def benchmark(label, series):
-      """Runs benchmark the given series."""
-
-      def make_dataset(element_size, num_calls, batch_size):  # pylint: disable=missing-docstring
-        k = 1024 * 1024
-        x = constant_op.constant(np.random.rand(element_size, 4 * k))
-        y = constant_op.constant(np.random.rand(4 * k, 1))
-        dataset = dataset_ops.Dataset.range(1000000000000).map(lambda _: (x, y))
-        dataset = dataset.map(
-            math_ops.matmul,
-            num_parallel_calls=num_calls).batch(batch_size=batch_size)
-        options = dataset_ops.Options()
-        options.experimental_optimization.apply_default_optimizations = False
-        return dataset.with_options(options)
-
-      for num_calls, inter_op, element_size, batch_size in series:
-        num_iters = 1024 // (
-            (element_size * batch_size) // min(num_calls, inter_op))
-        # By default the chained map().batch() calls will not be fused.
-        chained_dataset = make_dataset(element_size, num_calls, batch_size)
-        session_config = config_pb2.ConfigProto(
-            inter_op_parallelism_threads=inter_op, use_per_session_threads=True)
-        self.run_and_report_benchmark(
-            dataset=chained_dataset,
-            iters=num_iters,
-            num_elements=batch_size,
-            warmup=True,
-            session_config=session_config,
-            name=name("chained", label, num_calls, inter_op, element_size,
-                      batch_size))
-
-        # Apply an option to the default dataset that will fuse map().batch().
-        options = dataset_ops.Options()
-        options.experimental_optimization.map_and_batch_fusion = True
-        fused_dataset = chained_dataset.with_options(options)
-
-        self.run_and_report_benchmark(
-            dataset=fused_dataset,
-            iters=num_iters,
-            num_elements=batch_size,
-            warmup=True,
-            session_config=session_config,
-            name=name("fused", label, num_calls, inter_op, element_size,
-                      batch_size))
+    # Autotuned pipeline configuration.
+    fused_versus_chained_series = [
+        (dataset_ops.AUTOTUNE, 32, 1, 16, dataset_ops.AUTOTUNE, False),
+        (dataset_ops.AUTOTUNE, 32, 1, 16, None, True)
+    ]
 
     np.random.seed(_NUMPY_RANDOM_SEED)
-    benchmark("Sequential element size evaluation", seq_elem_size_series)
-    benchmark("Sequential batch size evaluation", seq_batch_size_series)
-    benchmark("Parallel element size evaluation", par_elem_size_series)
-    benchmark("Parallel batch size evaluation", par_batch_size_series)
-    benchmark("Transformation parallelism evaluation", par_num_calls_series)
-    benchmark("Threadpool size evaluation", par_inter_op_series)
+    self._benchmark_series("Sequential element size evaluation",
+                           seq_elem_size_series)
+    self._benchmark_series("Sequential batch size evaluation",
+                           seq_batch_size_series)
+    self._benchmark_series("Parallel element size evaluation",
+                           par_elem_size_series)
+    self._benchmark_series("Parallel batch size evaluation",
+                           par_batch_size_series)
+    self._benchmark_series("Transformation parallelism evaluation",
+                           par_map_num_calls_series)
+    self._benchmark_series("Threadpool size evaluation", par_inter_op_series)
+    self._benchmark_series("Autotune chained versus fused evaluation",
+                           fused_versus_chained_series)
 
 
 if __name__ == "__main__":
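
As an aside on the compute_num_iters heuristic introduced above: it scales the iteration count inversely with the per-iteration work, treating an autotuned map as 12-way parallel. A standalone restatement with worked examples (plain Python; AUTOTUNE = -1 mirrors the dataset_ops.AUTOTUNE sentinel):

AUTOTUNE = -1  # same sentinel value as dataset_ops.AUTOTUNE

def compute_num_iters(map_num_calls, inter_op, element_size, batch_size):
  # Work per iteration is element_size * batch_size matmuls, spread across the
  # effective parallelism; an autotuned map is assumed to be 12-way parallel.
  parallelism = min(
      12 if map_num_calls == AUTOTUNE else map_num_calls, inter_op)
  return 1024 // ((element_size * batch_size) // parallelism)

assert compute_num_iters(32, 32, 1, 512) == 64  # parallel batch-size series
assert compute_num_iters(AUTOTUNE, 32, 1, 16) == 1024  # autotuned series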