[tf.data] Clean up and standardize benchmarks.
PiperOrigin-RevId: 232519626
parent cc79252c0b
commit 36e8b05115
tensorflow/python/data/benchmarks/BUILD

@@ -17,15 +17,23 @@ py_test(
     ],
 )
 
+py_library(
+    name = "benchmark_base",
+    srcs = ["benchmark_base.py"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:session",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "batch_benchmark",
     srcs = ["batch_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
@@ -37,12 +45,8 @@ py_test(
     srcs = ["filter_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//third_party/py/numpy",
     ],
 )
 
@@ -51,9 +55,7 @@ py_test(
     srcs = ["from_tensor_slices_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
     ],
@@ -64,6 +66,7 @@ py_test(
     srcs = ["list_files_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":benchmark_base",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
@@ -78,11 +81,8 @@ py_test(
     srcs = ["map_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:session",
+        ":benchmark_base",
        "//tensorflow/python/data/ops:dataset_ops",
-        "//third_party/py/numpy",
     ],
 )
 
@@ -91,8 +91,7 @@ py_test(
     srcs = ["range_benchmark.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:session",
+        ":benchmark_base",
         "//tensorflow/python/data/ops:dataset_ops",
     ],
 )
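Note for anyone wiring up a new benchmark: because `:benchmark_base` already depends on `client_testlib`, `session`, `dataset_ops`, and numpy, individual benchmark targets can drop those deps, which accounts for most of the churn in the hunks above. A sketch of a hypothetical new target (the `shuffle_benchmark` name and file are illustrative, not part of this commit):

py_test(
    name = "shuffle_benchmark",
    srcs = ["shuffle_benchmark.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":benchmark_base",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)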
tensorflow/python/data/benchmarks/batch_benchmark.py

@@ -17,70 +17,37 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 import numpy as np
 
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class BatchBenchmark(test.Benchmark):
+class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.batch()`."""
 
-  def benchmarkBatchSparse(self):
+  def benchmark_batch_sparse(self):
     non_zeros_per_row_values = [0, 1, 5, 10, 100]
     batch_size_values = [1, 32, 64, 128, 1024]
 
-    sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
-    batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
-
-    dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
-    ).batch(batch_size_placeholder)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
     for non_zeros_per_row in non_zeros_per_row_values:
-
-      sparse_value = sparse_tensor.SparseTensorValue(
+      tensor = sparse_tensor.SparseTensor(
           indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
           values=np.arange(non_zeros_per_row, dtype=np.int64),
           dense_shape=[1000])
 
       for batch_size in batch_size_values:
-
-        with session.Session() as sess:
-          sess.run(iterator.initializer, feed_dict={
-              sparse_placeholder: sparse_value,
-              batch_size_placeholder: batch_size})
-          # Run five steps to warm up the session caches before taking the
-          # first measurement.
-          for _ in range(5):
-            sess.run(next_element.indices.op)
-          deltas = []
-          for _ in range(100):
-            start = time.time()
-            for _ in range(100):
-              sess.run(next_element.indices.op)
-            end = time.time()
-            deltas.append(end - start)
-
-        median_wall_time = np.median(deltas) / 100.0
-
-        self.report_benchmark(
-            iters=10000,
-            wall_time=median_wall_time,
-            name="sparse_num_elements_%d_batch_size_%d" %
-            (non_zeros_per_row, batch_size))
+        dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
+            batch_size)
+        self.run_and_report_benchmark(
+            dataset,
+            num_elements=100000 // batch_size,
+            iters=1,
+            name="sparse_num_elements_%d_batch_size_%d" % (non_zeros_per_row,
+                                                           batch_size))
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()
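For context, the shape of the pipeline this benchmark measures, rebuilt with the public API (a sketch assuming TF 1.x graph mode; the constants pick one point of the benchmark grid):

import numpy as np
import tensorflow as tf

non_zeros_per_row, batch_size = 5, 32
# One sparse row of length 1000 with `non_zeros_per_row` populated entries.
row = tf.SparseTensor(
    indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
    values=np.arange(non_zeros_per_row, dtype=np.int64),
    dense_shape=[1000])
# Repeat the row forever and batch it. `num_elements=100000 // batch_size`
# above keeps the total number of rows processed constant across batch sizes.
dataset = tf.data.Dataset.from_tensors(row).repeat().batch(batch_size)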
tensorflow/python/data/benchmarks/benchmark_base.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data benchmarking functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class DatasetBenchmarkBase(test.Benchmark):
+  """Base class for dataset benchmarks."""
+
+  def run_benchmark(self, dataset, num_elements, iters=1):
+    """Benchmarks the dataset.
+
+    Runs the dataset `iters` times. In each iteration, the benchmark measures
+    the time it takes to go through `num_elements` elements of the dataset.
+
+    Args:
+      dataset: Dataset to benchmark.
+      num_elements: Number of dataset elements to iterate through each benchmark
+        iteration.
+      iters: Number of times to repeat the timing.
+
+    Returns:
+      A float, representing the per-element wall time of the dataset in seconds.
+      This is the median time (with respect to `iters`) it takes for the dataset
+      to go through `num_elements` elements, divided by `num_elements.`
+    """
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
+    # the overhead of multiple `session.run()` calls. Note that this relies on
+    # the underlying implementation of `skip`: if it is optimized in the future,
+    # we will have to change this code.
+    dataset = dataset.skip(num_elements - 1)
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+    next_element = iterator.get_next()
+    next_element = nest.flatten(next_element)[0]
+
+    deltas = []
+    for _ in range(iters):
+      with session.Session() as sess:
+        # Run once to warm up the session caches.
+        sess.run(iterator.initializer)
+        sess.run(next_element)
+
+        sess.run(iterator.initializer)
+        start = time.time()
+        sess.run(next_element.op)
+        end = time.time()
+      deltas.append(end - start)
+    return np.median(deltas) / float(num_elements)
+
+  def run_and_report_benchmark(self,
+                               dataset,
+                               num_elements,
+                               name,
+                               iters=5,
+                               extras=None):
+    # Measure the per-element wall time.
+    wall_time = self.run_benchmark(dataset, num_elements, iters)
+
+    if extras is None:
+      extras = {}
+    extras["elements_per_second"] = 1 / wall_time
+    extras["num_elements"] = num_elements
+    # 'mode' represents the mechanism used for iterating over dataset elements.
+    name = "%s_mode_cpp" % name
+    self.report_benchmark(
+        wall_time=wall_time, iters=iters, name=name, extras=extras)
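A usage sketch of the new base class (the `ShuffleBenchmark` subclass and its pipeline are hypothetical, for illustration only):

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class ShuffleBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Hypothetical benchmark for `tf.data.Dataset.shuffle()`."""

  def benchmark_shuffle(self):
    dataset = dataset_ops.Dataset.range(10000).shuffle(100).repeat()
    # Reports the median per-element wall time, plus elements_per_second and
    # num_elements extras, under the name "shuffle_100_mode_cpp".
    self.run_and_report_benchmark(
        dataset, num_elements=100000, name="shuffle_100")


if __name__ == "__main__":
  benchmark_base.test.main()

Because `run_benchmark` appends `skip(num_elements - 1)`, a single `session.run()` pulls all but the last element inside the C++ iterator, so the Python dispatch overhead is paid once per timing rather than once per element.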
tensorflow/python/data/benchmarks/filter_benchmark.py

@@ -17,51 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class FilterBenchmark(test.Benchmark):
+class FilterBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.filter()`."""
 
   def _benchmark(self, predicate, name):
-    with ops.Graph().as_default():
-      dataset = (
-          dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
-      options = dataset_ops.Options()
-      options.experimental_optimization.apply_default_optimizations = False
-      dataset = dataset.with_options(options)
-      iterator = dataset_ops.make_one_shot_iterator(dataset)
-      next_element = iterator.get_next()
-
-      with session.Session() as sess:
-        for _ in range(5):
-          sess.run(next_element.op)
-        deltas = []
-        for _ in range(100):
-          start = time.time()
-          for _ in range(100):
-            sess.run(next_element.op)
-          end = time.time()
-          deltas.append(end - start)
-
-      median_wall_time = np.median(deltas) / 100
-      self.report_benchmark(iters=100, wall_time=median_wall_time, name=name)
+    dataset = (
+        dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
+    self.run_and_report_benchmark(dataset, num_elements=100000, name=name)
 
-  def benchmarkSimpleFunction(self):
+  def benchmark_simple_function(self):
     self._benchmark(array_ops.identity, "simple_function")
 
-  def benchmarkReturnComponentOptimization(self):
+  def benchmark_return_component_optimization(self):
     self._benchmark(lambda x: x, "return_component")
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()
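The two cases differ only in the predicate. A sketch of the distinction with the public API (assuming TF 1.x; `tf.identity` stands in for `array_ops.identity`, and the short-circuit reading is my interpretation of the benchmark names):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(True).repeat()

# "simple_function": the predicate routes the component through an op, so a
# defined function runs for every element.
simple = dataset.filter(tf.identity)

# "return_component": the predicate returns its input unchanged, which the
# runtime can detect and short-circuit without invoking a function.
short_circuit = dataset.filter(lambda x: x)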
tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py

@@ -17,170 +17,70 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
 import numpy as np
 
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class FromTensorSlicesBenchmark(test.Benchmark):
+class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
 
-  def benchmarkSliceRepeatBatch(self):
+  def benchmark_slice_repeat_batch(self):
     input_size = 10000
     batch_size = 100
     num_epochs = 100
+    num_elements = input_size * num_epochs // batch_size
 
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data)
-        .repeat(num_epochs + 1).batch(batch_size))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
+        dataset_ops.Dataset.from_tensor_slices(input_data).repeat(
+            num_epochs).batch(batch_size))
 
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        sess.run(next_element)
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          sess.run(next_element)
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
         name="slice_repeat_batch_input_%d_batch_%d" % (input_size, batch_size))
 
-  def benchmarkSliceRepeatBatchCallable(self):
+  def benchmark_reshape_slice_repeat(self):
     input_size = 10000
-    batch_size = 100
+    reshape_dim = [100, 100]
     num_epochs = 100
 
+    num_elements = num_epochs * reshape_dim[0]
+
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data)
-        .repeat(num_epochs + 1).batch(batch_size))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
+        dataset_ops.Dataset.from_tensor_slices(
+            input_data.reshape(*reshape_dim)).repeat(num_epochs))
 
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="slice_repeat_batch_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
+        name="reshape_slice_repeat_input_%d" % input_size,
+    )
 
-  def benchmarkReshapeSliceRepeatCallable(self):
+  def benchmark_slice_batch_cache_repeat(self):
     input_size = 10000
     batch_size = 100
     num_epochs = 100
+    num_elements = input_size * num_epochs // batch_size
 
     input_data = np.random.randn(input_size)
 
     dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
-        .repeat(num_epochs + 1))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="reshape_slice_repeat_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
-
-  def benchmarkSliceBatchCacheRepeatCallable(self):
-    input_size = 10000
-    batch_size = 100
-    num_epochs = 100
-
-    input_data = np.random.randn(input_size)
-
-    dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
-        .cache().repeat(num_epochs + 1))
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-
-    with session.Session() as sess:
-      sess.run(iterator.initializer)
-      get_next_element = sess.make_callable(next_element)
-      # Run one whole epoch to burn in the computation.
-      for _ in range(input_size // batch_size):
-        get_next_element()
-      deltas = []
-      try:
-        while True:
-          start = time.time()
-          get_next_element()
-          deltas.append(time.time() - start)
-      except errors.OutOfRangeError:
-        pass
-
-    median_wall_time = np.median(deltas)
-    self.report_benchmark(
-        iters=len(deltas),
-        wall_time=median_wall_time,
-        name="slice_batch_cache_repeat_callable_input_%d_batch_%d" %
-        (input_size, batch_size))
+        dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+            batch_size).cache().repeat(num_epochs))
+
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=num_elements,
+        name="slice_batch_cache_repeat_input_%d_batch_%d" % (input_size,
+                                                             batch_size))
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()
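With these constants each variant walks a comparable amount of data, e.g. `num_elements = input_size * num_epochs // batch_size = 10000 * 100 // 100 = 10000` batches for the first and third benchmarks, and `num_epochs * reshape_dim[0] = 10000` rows for the second. A compact sketch of the three pipelines with the public API (assuming TF 1.x):

import numpy as np
import tensorflow as tf

input_data = np.random.randn(10000)

# slice -> repeat -> batch
slice_repeat_batch = (
    tf.data.Dataset.from_tensor_slices(input_data).repeat(100).batch(100))

# reshape -> slice -> repeat: slicing yields 100 rows of 100 per epoch.
reshape_slice_repeat = (
    tf.data.Dataset.from_tensor_slices(input_data.reshape(100, 100))
    .repeat(100))

# slice -> batch -> cache -> repeat: epochs after the first read the cache.
slice_batch_cache_repeat = (
    tf.data.Dataset.from_tensor_slices(input_data).batch(100)
    .cache().repeat(100))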
tensorflow/python/data/benchmarks/map_benchmark.py

@@ -17,114 +17,51 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
 
 
 # TODO(b/119837791): Add eager benchmarks.
-class MapBenchmark(test.Benchmark):
-  """Bechmarks for `tf.data.Dataset.map()`."""
+class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
+  """Benchmarks for `tf.data.Dataset.map()`."""
 
-  def benchmarkChainOfMaps(self):
-    chain_lengths = [0, 1, 2, 5, 10, 20, 50]
-    for chain_length in chain_lengths:
-      for mode in ["general", "single-threaded", "short-circuit"]:
-        if mode == "general":
-          map_fn = lambda x: x + 1
-          use_inter_op_parallelism = True
-          benchmark_label = ""
-        if mode == "single-threaded":
-          map_fn = lambda x: x + 1
-          use_inter_op_parallelism = False
-          benchmark_label = "_single_threaded"
-        if mode == "short-circuit":
-          map_fn = lambda x: x
-          use_inter_op_parallelism = True  # should not have any significance
-          benchmark_label = "_short_circuit"
+  def benchmark_chain_of_maps(self):
 
-        with ops.Graph().as_default():
-          dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
-          for _ in range(chain_length):
-            dataset = dataset_ops.MapDataset(
-                dataset,
-                map_fn,
-                use_inter_op_parallelism=use_inter_op_parallelism)
-          options = dataset_ops.Options()
-          options.experimental_optimization.apply_default_optimizations = False
-          dataset = dataset.with_options(options)
-          iterator = dataset_ops.make_one_shot_iterator(dataset)
-          next_element = iterator.get_next()
+    def benchmark_helper(chain_length, map_fn, use_inter_op_parallelism, label):
+      dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+      for _ in range(chain_length):
+        dataset = dataset_ops.MapDataset(
+            dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=10000,
+          name="chain_length_%d%s" % (chain_length, label))
 
-          with session.Session() as sess:
-            for _ in range(5):
-              sess.run(next_element.op)
-            deltas = []
-            for _ in range(100):
-              start = time.time()
-              for _ in range(100):
-                sess.run(next_element.op)
-              end = time.time()
-              deltas.append(end - start)
-
-          median_wall_time = np.median(deltas) / 100
-          self.report_benchmark(
-              iters=1000,
-              wall_time=median_wall_time,
-              name="chain_length_%d%s" % (chain_length, benchmark_label))
+    chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+    for chain_length in chain_lengths:
+      benchmark_helper(chain_length, lambda x: x + 1, True, "")
+      benchmark_helper(chain_length, lambda x: x + 1, False, "_single_threaded")
+      benchmark_helper(chain_length, lambda x: x, True, "_short_circuit")
 
-  def benchmarkMapFanOut(self):
-    fan_outs = [1, 2, 5, 10, 20, 50, 100]
-    for fan_out in fan_outs:
-      for mode in ["general", "single-threaded", "short-circuit"]:
-        if mode == "general":
-          map_fn = lambda *xs: [x + 1 for x in xs]
-          use_inter_op_parallelism = True
-          benchmark_label = ""
-        if mode == "single-threaded":
-          map_fn = lambda *xs: [x + 1 for x in xs]
-          use_inter_op_parallelism = False
-          benchmark_label = "_single_threaded"
-        if mode == "short-circuit":
-          map_fn = lambda *xs: xs
-          use_inter_op_parallelism = True  # should not have any significance
-          benchmark_label = "_short_circuit"
+  def benchmark_map_fan_out(self):
+    fan_outs = [1, 2, 5, 10, 20, 50, 100]
 
-        with ops.Graph().as_default():
-          dataset = dataset_ops.Dataset.from_tensors(
-              tuple(0 for _ in range(fan_out))).repeat(None)
-          dataset = dataset_ops.MapDataset(
-              dataset,
-              map_fn,
-              use_inter_op_parallelism=use_inter_op_parallelism)
-          options = dataset_ops.Options()
-          options.experimental_optimization.apply_default_optimizations = False
-          dataset = dataset.with_options(options)
-          iterator = dataset_ops.make_one_shot_iterator(dataset)
-          next_element = iterator.get_next()
+    def benchmark_helper(fan_out, map_fn, use_inter_op_parallelism, label):
+      dataset = dataset_ops.Dataset.from_tensors(
+          tuple(0 for _ in range(fan_out))).repeat(None)
+      dataset = dataset_ops.MapDataset(
+          dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=10000,
+          name="fan_out_%d%s" % (fan_out, label))
 
-          with session.Session() as sess:
-            for _ in range(5):
-              sess.run(next_element[0].op)
-            deltas = []
-            for _ in range(100):
-              start = time.time()
-              for _ in range(100):
-                sess.run(next_element[0].op)
-              end = time.time()
-              deltas.append(end - start)
-
-          median_wall_time = np.median(deltas) / 100
-          self.report_benchmark(
-              iters=1000,
-              wall_time=median_wall_time,
-              name="fan_out_%d%s" % (fan_out, benchmark_label))
+    for fan_out in fan_outs:
+      benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], True, "")
+      benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], False,
+                       "_single_threaded")
+      benchmark_helper(fan_out, lambda *xs: xs, True, "_short_circuit")
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()
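The helper constructs `dataset_ops.MapDataset` directly rather than calling the public `Dataset.map()`, presumably because `use_inter_op_parallelism` is not exposed on the public method here. The closest public-API equivalent of the "general" mode is simply (a sketch, assuming TF 1.x):

import tensorflow as tf

chain_length = 10
dataset = tf.data.Dataset.from_tensors(0).repeat()
for _ in range(chain_length):
  dataset = dataset.map(lambda x: x + 1)  # one extra map stage per link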
tensorflow/python/data/benchmarks/range_benchmark.py

@@ -17,54 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import time
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
 
-_NUMPY_RANDOM_SEED = 42
 
-
-class RangeBenchmark(test.Benchmark):
+class RangeBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for `tf.data.Dataset.range()`."""
 
-  def _benchmarkRangeHelper(self, modeling_enabled):
-    num_elements = 10000000 if modeling_enabled else 50000000
+  def benchmark_range(self):
+    for modeling_enabled in [False, True]:
+      num_elements = 10000000 if modeling_enabled else 50000000
 
-    # Use `Dataset.skip()` and `Dataset.take()` to perform the iteration in
-    # C++, and focus on the minimal overheads (excluding Python invocation
-    # costs).
-    dataset = dataset_ops.Dataset.range(num_elements).skip(
-        num_elements - 1).take(1)
-    options = dataset_ops.Options()
-    options.experimental_autotune = modeling_enabled
-    options.experimental_optimization.apply_default_optimizations = False
-    dataset = dataset.with_options(options)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
+      options = dataset_ops.Options()
+      options.experimental_autotune = modeling_enabled
+      dataset = dataset_ops.Dataset.range(num_elements)
+      dataset = dataset.with_options(options)
 
-    with session.Session() as sess:
-      # Run once to warm up the session caches.
-      sess.run(iterator.initializer)
-      sess.run(next_element)
-
-      # Run once for timing.
-      sess.run(iterator.initializer)
-      start = time.time()
-      sess.run(next_element)
-      end = time.time()
-
-      time_per_element = (end - start) / num_elements
-      self.report_benchmark(
-          iters=num_elements,
-          wall_time=time_per_element,
+      self.run_and_report_benchmark(
+          dataset,
+          num_elements=num_elements,
           name="modeling_%s" % ("on" if modeling_enabled else "off"))
 
-  def benchmarkRange(self):
-    for modeling_enabled in [False, True]:
-      self._benchmarkRangeHelper(modeling_enabled)
-
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()
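One non-obvious constant above: the element count drops from 50M to 10M when modeling is enabled, presumably to keep wall-clock budgets comparable despite the autotuner's per-element bookkeeping (my reading; the commit does not say). A minimal sketch of one grid point, using the same internal modules as the file:

from tensorflow.python.data.ops import dataset_ops

modeling_enabled = True
options = dataset_ops.Options()
options.experimental_autotune = modeling_enabled  # toggles the autotuner
dataset = dataset_ops.Dataset.range(10000000).with_options(options)
# DatasetBenchmarkBase.run_and_report_benchmark then walks the range and
# reports median per-element wall time under "modeling_on_mode_cpp".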