[tf.data] Clean up and standardize benchmarks.

PiperOrigin-RevId: 232519626
Authored by Rachel Lim on 2019-02-05 11:04:08 -08:00; committed by TensorFlower Gardener
parent cc79252c0b
commit 36e8b05115
7 changed files with 209 additions and 367 deletions

tensorflow/python/data/benchmarks/BUILD

@@ -17,15 +17,23 @@ py_test(
     ],
 )
 
+py_library(
+    name = "benchmark_base",
+    srcs = ["benchmark_base.py"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:session",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "batch_benchmark",
     srcs = ["batch_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
@@ -37,12 +45,8 @@ py_test(
     srcs = ["filter_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//third_party/py/numpy",
     ],
 )
@@ -51,9 +55,7 @@ py_test(
     srcs = ["from_tensor_slices_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
     ],
@@ -64,6 +66,7 @@ py_test(
     srcs = ["list_files_benchmark.py"],
     srcs_version = "PY2AND3",
    deps = [
+        ":benchmark_base",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
@@ -78,11 +81,8 @@ py_test(
     srcs = ["map_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//third_party/py/numpy",
     ],
 )
@@ -91,8 +91,7 @@ py_test(
     srcs = ["range_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
     ],
 )
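With the shared `benchmark_base` library in place, a standardized benchmark target only needs `:benchmark_base` plus whatever ops the benchmark itself exercises. A minimal sketch of what a new target would presumably look like under this scheme (the `example_benchmark` name and file are hypothetical, not part of this change):

py_test(
    name = "example_benchmark",  # hypothetical target, for illustration only
    srcs = ["example_benchmark.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":benchmark_base",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)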

tensorflow/python/data/benchmarks/batch_benchmark.py

@@ -17,70 +17,37 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 import numpy as np
 
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
 
 
-# TODO(b/119837791): Add eager benchmarks.
-class BatchBenchmark(test.Benchmark):
+class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.batch()`."""
 
-  def benchmarkBatchSparse(self):
+  def benchmark_batch_sparse(self):
     non_zeros_per_row_values = [0, 1, 5, 10, 100]
     batch_size_values = [1, 32, 64, 128, 1024]
 
-    sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
-    batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
-
-    dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
-    ).batch(batch_size_placeholder)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
     for non_zeros_per_row in non_zeros_per_row_values:
-      sparse_value = sparse_tensor.SparseTensorValue(
+      tensor = sparse_tensor.SparseTensor(
           indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
           values=np.arange(non_zeros_per_row, dtype=np.int64),
           dense_shape=[1000])
 
       for batch_size in batch_size_values:
-        with session.Session() as sess:
-          sess.run(iterator.initializer, feed_dict={
-              sparse_placeholder: sparse_value,
-              batch_size_placeholder: batch_size})
-          # Run five steps to warm up the session caches before taking the
-          # first measurement.
-          for _ in range(5):
-            sess.run(next_element.indices.op)
-          deltas = []
-          for _ in range(100):
-            start = time.time()
-            for _ in range(100):
-              sess.run(next_element.indices.op)
-            end = time.time()
-            deltas.append(end - start)
-
-        median_wall_time = np.median(deltas) / 100.0
-        self.report_benchmark(
-            iters=10000,
-            wall_time=median_wall_time,
-            name="sparse_num_elements_%d_batch_size_%d" %
-            (non_zeros_per_row, batch_size))
+        dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
+            batch_size)
+        self.run_and_report_benchmark(
+            dataset,
+            num_elements=100000 // batch_size,
+            iters=1,
+            name="sparse_num_elements_%d_batch_size_%d" % (non_zeros_per_row,
+                                                           batch_size))
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()

tensorflow/python/data/benchmarks/benchmark_base.py (new file)

@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data benchmarking functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class DatasetBenchmarkBase(test.Benchmark):
+  """Base class for dataset benchmarks."""
+
+  def run_benchmark(self, dataset, num_elements, iters=1):
+    """Benchmarks the dataset.
+
+    Runs the dataset `iters` times. In each iteration, the benchmark measures
+    the time it takes to go through `num_elements` elements of the dataset.
+
+    Args:
+      dataset: Dataset to benchmark.
+      num_elements: Number of dataset elements to iterate through each
+        benchmark iteration.
+      iters: Number of times to repeat the timing.
+
+    Returns:
+      A float, representing the per-element wall time of the dataset in
+      seconds. This is the median time (with respect to `iters`) it takes for
+      the dataset to go through `num_elements` elements, divided by
+      `num_elements`.
+    """
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
+    # the overhead of multiple `session.run()` calls. Note that this relies on
+    # the underlying implementation of `skip`: if it is optimized in the
+    # future, we will have to change this code.
+    dataset = dataset.skip(num_elements - 1)
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+    next_element = iterator.get_next()
+    next_element = nest.flatten(next_element)[0]
+
+    deltas = []
+    for _ in range(iters):
+      with session.Session() as sess:
+        # Run once to warm up the session caches.
+        sess.run(iterator.initializer)
+        sess.run(next_element)
+
+        sess.run(iterator.initializer)
+        start = time.time()
+        sess.run(next_element.op)
+        end = time.time()
+        deltas.append(end - start)
+    return np.median(deltas) / float(num_elements)
+
+  def run_and_report_benchmark(self,
+                               dataset,
+                               num_elements,
+                               name,
+                               iters=5,
+                               extras=None):
+    # Measure the per-element wall time.
+    wall_time = self.run_benchmark(dataset, num_elements, iters)
+
+    if extras is None:
+      extras = {}
+    extras["elements_per_second"] = 1 / wall_time
+    extras["num_elements"] = num_elements
+    # 'mode' represents the mechanism used for iterating over dataset elements.
+    name = "%s_mode_cpp" % name
+    self.report_benchmark(
+        wall_time=wall_time, iters=iters, name=name, extras=extras)
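To illustrate how the base class is meant to be used, here is a minimal sketch of a benchmark built on it. The class name `RangeSquareBenchmark`, the method `benchmark_range_square`, and the chosen sizes are hypothetical; `DatasetBenchmarkBase`, `run_and_report_benchmark`, and the reported `_mode_cpp` suffix come from the file above.

# A sketch only: the benchmark class and method names below are hypothetical.
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class RangeSquareBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Hypothetical benchmark showing the DatasetBenchmarkBase usage pattern."""

  def benchmark_range_square(self):
    num_elements = 100000
    # The dataset must yield at least `num_elements` elements, since the base
    # class skips `num_elements - 1` of them and times the final get_next().
    dataset = dataset_ops.Dataset.range(num_elements).map(lambda x: x * x)
    # Reported as "range_square_mode_cpp" with the median per-element wall
    # time, plus elements_per_second and num_elements in the extras.
    self.run_and_report_benchmark(
        dataset, num_elements=num_elements, name="range_square", iters=5)


if __name__ == "__main__":
  benchmark_base.test.main()

Because `run_benchmark` returns the median wall time divided by `num_elements`, the reported `elements_per_second` is simply its reciprocal; for example, a median of 2 microseconds per element would be reported as 500,000 elements per second.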

tensorflow/python/data/benchmarks/filter_benchmark.py

@@ -17,51 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class FilterBenchmark(test.Benchmark):
+class FilterBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.filter()`."""
 
   def _benchmark(self, predicate, name):
-    with ops.Graph().as_default():
-      dataset = (
-          dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
-      options = dataset_ops.Options()
-      options.experimental_optimization.apply_default_optimizations = False
-      dataset = dataset.with_options(options)
-      iterator = dataset_ops.make_one_shot_iterator(dataset)
-      next_element = iterator.get_next()
-
-      with session.Session() as sess:
-        for _ in range(5):
-          sess.run(next_element.op)
-        deltas = []
-        for _ in range(100):
-          start = time.time()
-          for _ in range(100):
-            sess.run(next_element.op)
-          end = time.time()
-          deltas.append(end - start)
-
-        median_wall_time = np.median(deltas) / 100
-        self.report_benchmark(iters=100, wall_time=median_wall_time, name=name)
+    dataset = (
+        dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
+    self.run_and_report_benchmark(dataset, num_elements=100000, name=name)
 
-  def benchmarkSimpleFunction(self):
+  def benchmark_simple_function(self):
     self._benchmark(array_ops.identity, "simple_function")
 
-  def benchmarkReturnComponentOptimization(self):
+  def benchmark_return_component_optimization(self):
     self._benchmark(lambda x: x, "return_component")
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()

tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py

@@ -17,170 +17,70 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 import numpy as np
 
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class FromTensorSlicesBenchmark(test.Benchmark):
+class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
 
-  def benchmarkSliceRepeatBatch(self):
+  def benchmark_slice_repeat_batch(self):
     input_size = 10000
     batch_size = 100
     num_epochs = 100
+    num_elements = input_size * num_epochs // batch_size
 
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data)
-        .repeat(num_epochs + 1).batch(batch_size))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        sess.run(next_element)
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          sess.run(next_element)
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
+        dataset_ops.Dataset.from_tensor_slices(input_data).repeat(
+            num_epochs).batch(batch_size))
+
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
         name="slice_repeat_batch_input_%d_batch_%d" % (input_size, batch_size))
 
-  def benchmarkSliceRepeatBatchCallable(self):
+  def benchmark_reshape_slice_repeat(self):
     input_size = 10000
-    batch_size = 100
+    reshape_dim = [100, 100]
     num_epochs = 100
 
+    num_elements = num_epochs * reshape_dim[0]
+
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data)
-        .repeat(num_epochs + 1).batch(batch_size))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="slice_repeat_batch_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
-
-  def benchmarkReshapeSliceRepeatCallable(self):
-    input_size = 10000
-    batch_size = 100
-    num_epochs = 100
-
-    input_data = np.random.randn(input_size)
-
-    dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
-        .repeat(num_epochs + 1))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="reshape_slice_repeat_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
-
-  def benchmarkSliceBatchCacheRepeatCallable(self):
+        dataset_ops.Dataset.from_tensor_slices(
+            input_data.reshape(*reshape_dim)).repeat(num_epochs))
+
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
+        name="reshape_slice_repeat_input_%d" % input_size,
+    )
+
+  def benchmark_slice_batch_cache_repeat(self):
     input_size = 10000
     batch_size = 100
     num_epochs = 100
 
+    num_elements = input_size * num_epochs // batch_size
+
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
-        .cache().repeat(num_epochs + 1))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="slice_batch_cache_repeat_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
+        dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+            batch_size).cache().repeat(num_epochs))
+
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
+        name="slice_batch_cache_repeat_input_%d_batch_%d" % (input_size,
+                                                             batch_size))
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()

tensorflow/python/data/benchmarks/map_benchmark.py

@@ -17,114 +17,51 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class MapBenchmark(test.Benchmark):
-  """Bechmarks for `tf.data.Dataset.map()`."""
+class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
+  """Benchmarks for `tf.data.Dataset.map()`."""
 
-  def benchmarkChainOfMaps(self):
+  def benchmark_chain_of_maps(self):
+
+    def benchmark_helper(chain_length, map_fn, use_inter_op_parallelism, label):
+      dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+      for _ in range(chain_length):
+        dataset = dataset_ops.MapDataset(
+            dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=10000,
+          name="chain_length_%d%s" % (chain_length, label))
+
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      for mode in ["general", "single-threaded", "short-circuit"]:
-        if mode == "general":
-          map_fn = lambda x: x + 1
-          use_inter_op_parallelism = True
-          benchmark_label = ""
-        if mode == "single-threaded":
-          map_fn = lambda x: x + 1
-          use_inter_op_parallelism = False
-          benchmark_label = "_single_threaded"
-        if mode == "short-circuit":
-          map_fn = lambda x: x
-          use_inter_op_parallelism = True  # should not have any significance
-          benchmark_label = "_short_circuit"
-
-        with ops.Graph().as_default():
-          dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
-          for _ in range(chain_length):
-            dataset = dataset_ops.MapDataset(
-                dataset,
-                map_fn,
-                use_inter_op_parallelism=use_inter_op_parallelism)
-          options = dataset_ops.Options()
-          options.experimental_optimization.apply_default_optimizations = False
-          dataset = dataset.with_options(options)
-          iterator = dataset_ops.make_one_shot_iterator(dataset)
-          next_element = iterator.get_next()
-
-          with session.Session() as sess:
-            for _ in range(5):
-              sess.run(next_element.op)
-            deltas = []
-            for _ in range(100):
-              start = time.time()
-              for _ in range(100):
-                sess.run(next_element.op)
-              end = time.time()
-              deltas.append(end - start)
-
-            median_wall_time = np.median(deltas) / 100
-            self.report_benchmark(
-                iters=1000,
-                wall_time=median_wall_time,
-                name="chain_length_%d%s" % (chain_length, benchmark_label))
+      benchmark_helper(chain_length, lambda x: x + 1, True, "")
+      benchmark_helper(chain_length, lambda x: x + 1, False, "_single_threaded")
+      benchmark_helper(chain_length, lambda x: x, True, "_short_circuit")
 
-  def benchmarkMapFanOut(self):
+  def benchmark_map_fan_out(self):
     fan_outs = [1, 2, 5, 10, 20, 50, 100]
+
+    def benchmark_helper(fan_out, map_fn, use_inter_op_parallelism, label):
+      dataset = dataset_ops.Dataset.from_tensors(
+          tuple(0 for _ in range(fan_out))).repeat(None)
+      dataset = dataset_ops.MapDataset(
+          dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=10000,
+          name="fan_out_%d%s" % (fan_out, label))
+
     for fan_out in fan_outs:
-      for mode in ["general", "single-threaded", "short-circuit"]:
-        if mode == "general":
-          map_fn = lambda *xs: [x + 1 for x in xs]
-          use_inter_op_parallelism = True
-          benchmark_label = ""
-        if mode == "single-threaded":
-          map_fn = lambda *xs: [x + 1 for x in xs]
-          use_inter_op_parallelism = False
-          benchmark_label = "_single_threaded"
-        if mode == "short-circuit":
-          map_fn = lambda *xs: xs
-          use_inter_op_parallelism = True  # should not have any significance
-          benchmark_label = "_short_circuit"
-
-        with ops.Graph().as_default():
-          dataset = dataset_ops.Dataset.from_tensors(
-              tuple(0 for _ in range(fan_out))).repeat(None)
-          dataset = dataset_ops.MapDataset(
-              dataset,
-              map_fn,
-              use_inter_op_parallelism=use_inter_op_parallelism)
-          options = dataset_ops.Options()
-          options.experimental_optimization.apply_default_optimizations = False
-          dataset = dataset.with_options(options)
-          iterator = dataset_ops.make_one_shot_iterator(dataset)
-          next_element = iterator.get_next()
-
-          with session.Session() as sess:
-            for _ in range(5):
-              sess.run(next_element[0].op)
-            deltas = []
-            for _ in range(100):
-              start = time.time()
-              for _ in range(100):
-                sess.run(next_element[0].op)
-              end = time.time()
-              deltas.append(end - start)
-
-            median_wall_time = np.median(deltas) / 100
-            self.report_benchmark(
-                iters=1000,
-                wall_time=median_wall_time,
-                name="fan_out_%d%s" % (fan_out, benchmark_label))
+      benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], True, "")
+      benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], False,
+                       "_single_threaded")
+      benchmark_helper(fan_out, lambda *xs: xs, True, "_short_circuit")
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()

tensorflow/python/data/benchmarks/range_benchmark.py

@@ -17,54 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
 
-_NUMPY_RANDOM_SEED = 42
 
-
-class RangeBenchmark(test.Benchmark):
+class RangeBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.range()`."""
 
-  def _benchmarkRangeHelper(self, modeling_enabled):
-    num_elements = 10000000 if modeling_enabled else 50000000
-
-    # Use `Dataset.skip()` and `Dataset.take()` to perform the iteration in
-    # C++, and focus on the minimal overheads (excluding Python invocation
-    # costs).
-    dataset = dataset_ops.Dataset.range(num_elements).skip(
-        num_elements - 1).take(1)
-    options = dataset_ops.Options()
-    options.experimental_autotune = modeling_enabled
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      # Run once to warm up the session caches.
-      sess.run(iterator.initializer)
-      sess.run(next_element)
-
-      # Run once for timing.
-      sess.run(iterator.initializer)
-      start = time.time()
-      sess.run(next_element)
-      end = time.time()
-
-    time_per_element = (end - start) / num_elements
-    self.report_benchmark(
-        iters=num_elements,
-        wall_time=time_per_element,
-        name="modeling_%s" % ("on" if modeling_enabled else "off"))
-
-  def benchmarkRange(self):
+  def benchmark_range(self):
     for modeling_enabled in [False, True]:
-      self._benchmarkRangeHelper(modeling_enabled)
+      num_elements = 10000000 if modeling_enabled else 50000000
+      options = dataset_ops.Options()
+      options.experimental_autotune = modeling_enabled
+      dataset = dataset_ops.Dataset.range(num_elements)
+      dataset = dataset.with_options(options)
+
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=num_elements,
+          name="modeling_%s" % ("on" if modeling_enabled else "off"))
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()