From 8a652f7979135e906c68afd50a8c501b77f2bc81 Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Mon, 2 Jul 2018 16:12:49 -0700
Subject: [PATCH] [tf.data] Adding code for benchmarking map+batch fusion.

PiperOrigin-RevId: 203029765
---
 .../contrib/data/python/kernel_tests/BUILD    |  1 +
 .../kernel_tests/map_dataset_op_test.py       | 82 +++++++++++++++++++
 2 files changed, 83 insertions(+)

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index d81654e039c..c9435eadcd0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -188,6 +188,7 @@ py_test(
         "optonly",
     ],
     deps = [
+        "//tensorflow/contrib/data/python/ops:batching",
         "//tensorflow/contrib/data/python/ops:error_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 270a2297b4d..a075dfd8b56 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -17,11 +17,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
 import os
+import time
 
 import numpy as np
 
+from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -135,5 +140,82 @@ class MapDatasetTest(test.TestCase):
         sess.run(get_next)
 
 
+class MapDatasetBenchmark(test.Benchmark):
+
+  def benchmarkMapAndBatch(self):
+    small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
+    large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])
+
+    num_iters = 100
+
+    def benchmark(series):
+
+      for num_calls, inter_op, element_size, batch_size, num_steps in series:
+        dataset = dataset_ops.Dataset.from_tensors(
+            np.random.randint(100, size=element_size)).repeat().map(
+                lambda x: x,
+                num_parallel_calls=num_calls).batch(batch_size=batch_size)
+        iterator = dataset.make_one_shot_iterator()
+        get_next = iterator.get_next()
+
+        fused_dataset = dataset_ops.Dataset.from_tensors(
+            np.random.randint(100, size=element_size)).repeat(None).apply(
+                batching.map_and_batch(
+                    lambda x: x,
+                    num_parallel_calls=num_calls,
+                    batch_size=batch_size))
+        fused_iterator = fused_dataset.make_one_shot_iterator()
+        fused_get_next = fused_iterator.get_next()
+
+        fused_deltas = []
+        with session.Session(
+            config=config_pb2.ConfigProto(
+                inter_op_parallelism_threads=inter_op)) as sess:
+
+          for _ in range(5):
+            sess.run(fused_get_next)
+          for _ in range(num_iters):
+            start = time.time()
+            for _ in range(num_steps):
+              sess.run(fused_get_next)
+            end = time.time()
+            fused_deltas.append(end - start)
+
+        chained_deltas = []
+        with session.Session(
+            config=config_pb2.ConfigProto(
+                inter_op_parallelism_threads=inter_op)) as sess:
+          for _ in range(5):
+            sess.run(get_next)
+          for _ in range(num_iters):
+            start = time.time()
+            for _ in range(num_steps):
+              sess.run(get_next)
+            end = time.time()
+            chained_deltas.append(end - start)
+
+        chained_wall_time = np.median(chained_deltas) / num_iters
+        fused_wall_time = np.median(fused_deltas) / num_iters
+        print(
+            "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+            "element size: %d, chained wall time: %f, fused wall time: %f" %
+            (batch_size, num_calls, inter_op, element_size, chained_wall_time,
+             fused_wall_time))
+
+        self.report_benchmark(
+            iters=num_iters,
+            wall_time=chained_wall_time,
+            name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
+            % (batch_size, num_calls, inter_op, element_size))
+
+        self.report_benchmark(
+            iters=num_iters,
+            wall_time=fused_wall_time,
+            name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
+            % (batch_size, num_calls, inter_op, element_size))
+
+    benchmark(small)
+    benchmark(large)
+
 if __name__ == "__main__":
   test.main()