Improving benchmark for comparing map + batch chaining vs. fusion.

PiperOrigin-RevId: 203789142
Jiri Simsa, 2018-07-09 10:35:08 -07:00 (committed by TensorFlower Gardener)
parent b6e02f68fe
commit 9fea659a48

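For context, the two pipelines this benchmark compares look roughly as follows. This is a minimal sketch against the TF 1.x-era APIs used in the file under change; the identity-style map function and the sizes are illustrative placeholders, and the `batching` import path is assumed from the contrib layout of that era.

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import batching  # assumed contrib path

# A toy infinite source of 8-element vectors.
dataset = tf.data.Dataset.from_tensors(np.random.rand(8)).repeat()

# Chained: `map` and `batch` execute as two separate pipeline stages.
chained = dataset.map(lambda x: x * 2, num_parallel_calls=4).batch(16)

# Fused: a single `map_and_batch` stage performs both, overlapping the work.
fused = dataset.apply(
    batching.map_and_batch(
        lambda x: x * 2, num_parallel_calls=4, batch_size=16))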

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import hashlib
 import itertools
 import os
 import time
@@ -32,9 +33,12 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
+_NUMPY_RANDOM_SEED = 42
+
 
 class MapDatasetTest(test.TestCase):
@@ -142,80 +146,123 @@ class MapDatasetTest(test.TestCase):
 
 class MapDatasetBenchmark(test.Benchmark):
 
+  # The purpose of this benchmark is to compare the performance of chaining vs
+  # fusing of the map and batch transformations across various configurations.
+  #
+  # NOTE: It is recommended to build the benchmark with
+  # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+  # and execute it on a machine with at least 32 CPU cores.
   def benchmarkMapAndBatch(self):
-    small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
-    large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])
 
-    num_iters = 100
+    # Sequential pipeline configurations.
+    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+    # Parallel pipeline configuration.
+    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+    par_batch_size_series = itertools.product([32], [32], [1],
+                                              [128, 256, 512, 1024])
+    par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
 
-    def benchmark(series):
-      for num_calls, inter_op, element_size, batch_size, num_steps in series:
-        dataset = dataset_ops.Dataset.from_tensors(
-            np.random.randint(100, size=element_size)).repeat().map(
-                lambda x: x,
-                num_parallel_calls=num_calls).batch(batch_size=batch_size)
-        iterator = dataset.make_one_shot_iterator()
-        get_next = iterator.get_next()
+    def name(method, label, num_calls, inter_op, element_size, batch_size):
+      return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+          method,
+          hashlib.sha1(label).hexdigest(),
+          num_calls,
+          inter_op,
+          element_size,
+          batch_size,
+      ))
 
-        fused_dataset = dataset_ops.Dataset.from_tensors(
-            np.random.randint(100, size=element_size)).repeat(None).apply(
-                batching.map_and_batch(
-                    lambda x: x,
-                    num_parallel_calls=num_calls,
-                    batch_size=batch_size))
+    def benchmark(label, series):
+
+      print("%s:" % label)
+      for num_calls, inter_op, element_size, batch_size in series:
+
+        num_iters = 1024 // (
+            (element_size * batch_size) // min(num_calls, inter_op))
+        k = 1024 * 1024
+        dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+            element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+        chained_dataset = dataset.map(
+            math_ops.matmul,
+            num_parallel_calls=num_calls).batch(batch_size=batch_size)
+        chained_iterator = chained_dataset.make_one_shot_iterator()
+        chained_get_next = chained_iterator.get_next()
+
+        chained_deltas = []
+        with session.Session(
+            config=config_pb2.ConfigProto(
+                inter_op_parallelism_threads=inter_op,
+                use_per_session_threads=True)) as sess:
+          for _ in range(5):
+            sess.run(chained_get_next.op)
+          for _ in range(num_iters):
+            start = time.time()
+            sess.run(chained_get_next.op)
+            end = time.time()
+            chained_deltas.append(end - start)
+
+        fused_dataset = dataset = dataset.apply(
+            batching.map_and_batch(
+                math_ops.matmul,
+                num_parallel_calls=num_calls,
+                batch_size=batch_size))
         fused_iterator = fused_dataset.make_one_shot_iterator()
         fused_get_next = fused_iterator.get_next()
 
         fused_deltas = []
         with session.Session(
             config=config_pb2.ConfigProto(
-                inter_op_parallelism_threads=inter_op)) as sess:
+                inter_op_parallelism_threads=inter_op,
+                use_per_session_threads=True)) as sess:
 
           for _ in range(5):
-            sess.run(fused_get_next)
+            sess.run(fused_get_next.op)
           for _ in range(num_iters):
             start = time.time()
-            for _ in range(num_steps):
-              sess.run(fused_get_next)
+            sess.run(fused_get_next.op)
             end = time.time()
             fused_deltas.append(end - start)
 
-        chained_deltas = []
-        with session.Session(
-            config=config_pb2.ConfigProto(
-                inter_op_parallelism_threads=inter_op)) as sess:
-          for _ in range(5):
-            sess.run(get_next)
-          for _ in range(num_iters):
-            start = time.time()
-            for _ in range(num_steps):
-              sess.run(get_next)
-            end = time.time()
-            chained_deltas.append(end - start)
-
-        chained_wall_time = np.median(chained_deltas) / num_iters
-        fused_wall_time = np.median(fused_deltas) / num_iters
         print(
             "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
-            "element size: %d, chained wall time: %f, fused wall time: %f" %
-            (batch_size, num_calls, inter_op, element_size, chained_wall_time,
-             fused_wall_time))
+            "element size: %d, num iters: %d\nchained wall time: %f (median), "
+            "%f (mean), %f (stddev), %f (min), %f (max)\n  fused wall time: "
+            "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n  "
+            "chained/fused: %.2fx (median), %.2fx (mean)" %
+            (batch_size, num_calls, inter_op, element_size, num_iters,
+             np.median(chained_deltas), np.mean(chained_deltas),
+             np.std(chained_deltas), np.min(chained_deltas),
+             np.max(chained_deltas), np.median(fused_deltas),
+             np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+             np.max(fused_deltas),
+             np.median(chained_deltas) / np.median(fused_deltas),
+             np.mean(chained_deltas) / np.mean(fused_deltas)))
 
         self.report_benchmark(
             iters=num_iters,
-            wall_time=chained_wall_time,
-            name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
-            % (batch_size, num_calls, inter_op, element_size))
+            wall_time=np.median(chained_deltas),
+            name=name("chained", label, num_calls, inter_op, element_size,
+                      batch_size))
 
         self.report_benchmark(
             iters=num_iters,
-            wall_time=fused_wall_time,
-            name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
-            % (batch_size, num_calls, inter_op, element_size))
+            wall_time=np.median(fused_deltas),
+            name=name("fused", label, num_calls, inter_op, element_size,
+                      batch_size))
 
-    benchmark(small)
-    benchmark(large)
+      print("")
+
+    np.random.seed(_NUMPY_RANDOM_SEED)
+    benchmark("Sequential element size evaluation", seq_elem_size_series)
+    benchmark("Sequential batch size evaluation", seq_batch_size_series)
+    benchmark("Parallel element size evaluation", par_elem_size_series)
+    benchmark("Parallel batch size evaluation", par_batch_size_series)
+    benchmark("Transformation parallelism evaluation", par_num_calls_series)
+    benchmark("Threadpool size evaluation", par_inter_op_series)
 
 
 if __name__ == "__main__":
   test.main()
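As a spot check of the new iteration-count heuristic (a standalone, illustrative calculation, not part of the commit), one configuration from the sequential element-size series works out as follows: heavier per-step configurations are timed for proportionally fewer iterations.

# Sequential element-size series: num_calls = inter_op = 1, batch_size = 16.
num_calls, inter_op, element_size, batch_size = 1, 1, 4, 16
num_iters = 1024 // ((element_size * batch_size) // min(num_calls, inter_op))
assert num_iters == 16  # 1024 // (64 // 1): heavier steps, fewer iterations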