[tf.data] Adding code for benchmarking map+batch fusion.

PiperOrigin-RevId: 203029765
This commit is contained in:
Jiri Simsa 2018-07-02 16:12:49 -07:00 committed by TensorFlower Gardener
parent 0a75d72566
commit 8a652f7979
2 changed files with 83 additions and 0 deletions
tensorflow/contrib/data/python/kernel_tests

View File

@ -188,6 +188,7 @@ py_test(
"optonly",
],
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/contrib/data/python/ops:error_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -17,11 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import os
import time
import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -135,5 +140,82 @@ class MapDatasetTest(test.TestCase):
sess.run(get_next)
class MapDatasetBenchmark(test.Benchmark):
  """Benchmarks a chained `map(...).batch(...)` pipeline against the fused
  `map_and_batch` transformation."""

  def benchmarkMapAndBatch(self):
    """Measures median per-step wall time for chained vs. fused map+batch.

    Sweeps configurations of (num_parallel_calls, inter_op_parallelism,
    element_size, batch_size, num_steps) and reports one "chained" and one
    "fused" benchmark entry per configuration.
    """
    # Each tuple: (num_parallel_calls, inter_op, element_size, batch_size,
    # num_steps). `num_steps` is how many `get_next` runs one timing sample
    # covers.
    small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
    large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])

    num_iters = 100  # Number of timing samples collected per configuration.

    def benchmark(series):
      for num_calls, inter_op, element_size, batch_size, num_steps in series:
        # Chained pipeline: separate map and batch stages.
        dataset = dataset_ops.Dataset.from_tensors(
            np.random.randint(100, size=element_size)).repeat().map(
                lambda x: x,
                num_parallel_calls=num_calls).batch(batch_size=batch_size)
        iterator = dataset.make_one_shot_iterator()
        get_next = iterator.get_next()

        # Fused pipeline: a single map_and_batch transformation.
        fused_dataset = dataset_ops.Dataset.from_tensors(
            np.random.randint(100, size=element_size)).repeat(None).apply(
                batching.map_and_batch(
                    lambda x: x,
                    num_parallel_calls=num_calls,
                    batch_size=batch_size))
        fused_iterator = fused_dataset.make_one_shot_iterator()
        fused_get_next = fused_iterator.get_next()

        def measure(op):
          """Returns `num_iters` wall-time samples for `op`, each covering
          `num_steps` session runs (after a short warm-up)."""
          deltas = []
          with session.Session(
              config=config_pb2.ConfigProto(
                  inter_op_parallelism_threads=inter_op)) as sess:
            # Warm up the input pipeline before timing.
            for _ in range(5):
              sess.run(op)
            for _ in range(num_iters):
              start = time.time()
              for _ in range(num_steps):
                sess.run(op)
              deltas.append(time.time() - start)
          return deltas

        # The fused pipeline is measured first, matching the original order.
        fused_deltas = measure(fused_get_next)
        chained_deltas = measure(get_next)

        # BUG FIX: each delta spans `num_steps` session runs, so per-step
        # wall time is the median delta divided by `num_steps`. The previous
        # code divided by `num_iters`, which is only correct when the two
        # values coincide (true for `small`, off by 10x for `large`).
        chained_wall_time = np.median(chained_deltas) / num_steps
        fused_wall_time = np.median(fused_deltas) / num_steps
        print(
            "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
            "element size: %d, chained wall time: %f, fused wall time: %f" %
            (batch_size, num_calls, inter_op, element_size, chained_wall_time,
             fused_wall_time))

        self.report_benchmark(
            iters=num_iters,
            wall_time=chained_wall_time,
            name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
            % (batch_size, num_calls, inter_op, element_size))
        self.report_benchmark(
            iters=num_iters,
            wall_time=fused_wall_time,
            name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
            % (batch_size, num_calls, inter_op, element_size))

    benchmark(small)
    benchmark(large)
# Entry point: runs the tests (and benchmarks, when requested) in this file.
if __name__ == "__main__":
  test.main()