[tf.data] Adding code for benchmarking map+batch fusion.
PiperOrigin-RevId: 203029765
This commit is contained in:
parent
0a75d72566
commit
8a652f7979
@ -188,6 +188,7 @@ py_test(
|
|||||||
"optonly",
|
"optonly",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/contrib/data/python/ops:batching",
|
||||||
"//tensorflow/contrib/data/python/ops:error_ops",
|
"//tensorflow/contrib/data/python/ops:error_ops",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -17,11 +17,16 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.data.python.ops import batching
|
||||||
from tensorflow.contrib.data.python.ops import error_ops
|
from tensorflow.contrib.data.python.ops import error_ops
|
||||||
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -135,5 +140,82 @@ class MapDatasetTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
|
||||||
|
class MapDatasetBenchmark(test.Benchmark):
|
||||||
|
|
||||||
|
def benchmarkMapAndBatch(self):
|
||||||
|
small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
|
||||||
|
large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])
|
||||||
|
|
||||||
|
num_iters = 100
|
||||||
|
|
||||||
|
def benchmark(series):
|
||||||
|
|
||||||
|
for num_calls, inter_op, element_size, batch_size, num_steps in series:
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
np.random.randint(100, size=element_size)).repeat().map(
|
||||||
|
lambda x: x,
|
||||||
|
num_parallel_calls=num_calls).batch(batch_size=batch_size)
|
||||||
|
iterator = dataset.make_one_shot_iterator()
|
||||||
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
fused_dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
np.random.randint(100, size=element_size)).repeat(None).apply(
|
||||||
|
batching.map_and_batch(
|
||||||
|
lambda x: x,
|
||||||
|
num_parallel_calls=num_calls,
|
||||||
|
batch_size=batch_size))
|
||||||
|
fused_iterator = fused_dataset.make_one_shot_iterator()
|
||||||
|
fused_get_next = fused_iterator.get_next()
|
||||||
|
|
||||||
|
fused_deltas = []
|
||||||
|
with session.Session(
|
||||||
|
config=config_pb2.ConfigProto(
|
||||||
|
inter_op_parallelism_threads=inter_op)) as sess:
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
sess.run(fused_get_next)
|
||||||
|
for _ in range(num_iters):
|
||||||
|
start = time.time()
|
||||||
|
for _ in range(num_steps):
|
||||||
|
sess.run(fused_get_next)
|
||||||
|
end = time.time()
|
||||||
|
fused_deltas.append(end - start)
|
||||||
|
|
||||||
|
chained_deltas = []
|
||||||
|
with session.Session(
|
||||||
|
config=config_pb2.ConfigProto(
|
||||||
|
inter_op_parallelism_threads=inter_op)) as sess:
|
||||||
|
for _ in range(5):
|
||||||
|
sess.run(get_next)
|
||||||
|
for _ in range(num_iters):
|
||||||
|
start = time.time()
|
||||||
|
for _ in range(num_steps):
|
||||||
|
sess.run(get_next)
|
||||||
|
end = time.time()
|
||||||
|
chained_deltas.append(end - start)
|
||||||
|
|
||||||
|
chained_wall_time = np.median(chained_deltas) / num_iters
|
||||||
|
fused_wall_time = np.median(fused_deltas) / num_iters
|
||||||
|
print(
|
||||||
|
"batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
|
||||||
|
"element size: %d, chained wall time: %f, fused wall time: %f" %
|
||||||
|
(batch_size, num_calls, inter_op, element_size, chained_wall_time,
|
||||||
|
fused_wall_time))
|
||||||
|
|
||||||
|
self.report_benchmark(
|
||||||
|
iters=num_iters,
|
||||||
|
wall_time=chained_wall_time,
|
||||||
|
name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
|
||||||
|
% (batch_size, num_calls, inter_op, element_size))
|
||||||
|
|
||||||
|
self.report_benchmark(
|
||||||
|
iters=num_iters,
|
||||||
|
wall_time=fused_wall_time,
|
||||||
|
name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
|
||||||
|
% (batch_size, num_calls, inter_op, element_size))
|
||||||
|
|
||||||
|
benchmark(small)
|
||||||
|
benchmark(large)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user