[tf.data] Adding code for benchmarking map+batch fusion.
PiperOrigin-RevId: 203029765
This commit is contained in:
parent
0a75d72566
commit
8a652f7979
@ -188,6 +188,7 @@ py_test(
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:batching",
|
||||
"//tensorflow/contrib/data/python/ops:error_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -17,11 +17,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import batching
|
||||
from tensorflow.contrib.data.python.ops import error_ops
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -135,5 +140,82 @@ class MapDatasetTest(test.TestCase):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
class MapDatasetBenchmark(test.Benchmark):
  """Benchmarks comparing chained `map(...).batch(...)` against the fused
  `map_and_batch` transformation over a grid of parallelism/size settings."""

  def benchmarkMapAndBatch(self):
    """Measures per-step wall time for chained vs. fused map+batch.

    Sweeps two parameter grids (a "small" and a "large" configuration) of
    (num_parallel_calls, inter_op_parallelism, element_size, batch_size,
    num_steps), prints a comparison line per configuration, and reports both
    results via `report_benchmark`.
    """
    # Grid entries: (num_parallel_calls, inter_op_parallelism_threads,
    # element_size, batch_size, num_steps).
    small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
    large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])

    num_iters = 100

    def measure(op, inter_op, num_steps):
      """Returns `num_iters` wall-time deltas, each covering `num_steps`
      runs of `op` in a fresh session with the given inter-op parallelism."""
      deltas = []
      with session.Session(
          config=config_pb2.ConfigProto(
              inter_op_parallelism_threads=inter_op)) as sess:
        # Warm up to exclude one-off graph/session setup cost.
        for _ in range(5):
          sess.run(op)
        for _ in range(num_iters):
          start = time.time()
          for _ in range(num_steps):
            sess.run(op)
          end = time.time()
          deltas.append(end - start)
      return deltas

    def benchmark(series):
      for num_calls, inter_op, element_size, batch_size, num_steps in series:
        # Chained pipeline: map(...) followed by a separate batch(...).
        dataset = dataset_ops.Dataset.from_tensors(
            np.random.randint(100, size=element_size)).repeat().map(
                lambda x: x,
                num_parallel_calls=num_calls).batch(batch_size=batch_size)
        iterator = dataset.make_one_shot_iterator()
        get_next = iterator.get_next()

        # Fused pipeline: the single map_and_batch transformation.
        fused_dataset = dataset_ops.Dataset.from_tensors(
            np.random.randint(100, size=element_size)).repeat(None).apply(
                batching.map_and_batch(
                    lambda x: x,
                    num_parallel_calls=num_calls,
                    batch_size=batch_size))
        fused_iterator = fused_dataset.make_one_shot_iterator()
        fused_get_next = fused_iterator.get_next()

        fused_deltas = measure(fused_get_next, inter_op, num_steps)
        chained_deltas = measure(get_next, inter_op, num_steps)

        # BUG FIX: each delta times `num_steps` session runs, so the
        # per-step wall time is the median delta divided by num_steps.
        # Dividing by num_iters (as before) was only correct for the
        # "small" series, where num_steps happens to equal num_iters.
        chained_wall_time = np.median(chained_deltas) / num_steps
        fused_wall_time = np.median(fused_deltas) / num_steps
        print(
            "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
            "element size: %d, chained wall time: %f, fused wall time: %f" %
            (batch_size, num_calls, inter_op, element_size, chained_wall_time,
             fused_wall_time))

        self.report_benchmark(
            iters=num_iters,
            wall_time=chained_wall_time,
            name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
            % (batch_size, num_calls, inter_op, element_size))

        self.report_benchmark(
            iters=num_iters,
            wall_time=fused_wall_time,
            name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
            % (batch_size, num_calls, inter_op, element_size))

    benchmark(small)
    benchmark(large)
|
||||
|
||||
# Standard TensorFlow test entry point: running this module directly
# executes the tests, and runs the benchmarks when invoked with the
# --benchmarks flag handled by test.main().
if __name__ == "__main__":
  test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user