Merge pull request #47015 from kvignesh1420:eager-exp-bechmarks-p2
PiperOrigin-RevId: 356831159
Change-Id: I4503a76dec610fb0e5817110fa54f1645cf84ecb
commit cc2635d67c
@@ -31,12 +31,116 @@ from tensorflow.python.platform import test
class DatasetBenchmarkBase(test.Benchmark):
  """Base class for dataset benchmarks."""

  def _run_eager_benchmark(self, iterable, iters, warmup):
    """Benchmarks the iterable in eager mode.

    Runs the iterable `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the iterable.

    Args:
      iterable: The tf op or tf.data Dataset to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, warms up the session caches by running an untimed run.

    Returns:
      A float, representing the median time (with respect to `iters`)
      it takes for the iterable to be executed `iters` times.
    """

    deltas = []
    if not context.executing_eagerly():
      raise RuntimeError(
          "Eager mode benchmarking is not supported in graph mode.")

    for _ in range(iters):
      if warmup:
        iterator = iter(iterable)
        next(iterator)

      iterator = iter(iterable)
      start = time.time()
      next(iterator)
      end = time.time()
      deltas.append(end - start)
    return np.median(deltas)

  def _run_graph_benchmark(self,
                           iterable,
                           iters,
                           warmup,
                           session_config,
                           initializer=None):
    """Benchmarks the iterable in graph mode.

    Runs the iterable `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the iterable.

    Args:
      iterable: The tf op or tf.data Dataset to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, warms up the session caches by running an untimed run.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.
      initializer: The initializer op required to initialize the iterable.

    Returns:
      A float, representing the median time (with respect to `iters`)
      it takes for the iterable to be executed `iters` times.
    """

    deltas = []
    if context.executing_eagerly():
      raise RuntimeError(
          "Graph mode benchmarking is not supported in eager mode.")

    for _ in range(iters):
      with session.Session(config=session_config) as sess:
        if warmup:
          # Run once to warm up the session caches.
          if initializer:
            sess.run(initializer)
          sess.run(iterable)

        if initializer:
          sess.run(initializer)
        start = time.time()
        sess.run(iterable)
        end = time.time()
        deltas.append(end - start)
    return np.median(deltas)

  def run_op_benchmark(self, op, iters=1, warmup=True, session_config=None):
    """Benchmarks the op.

    Runs the op `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the op.

    Args:
      op: The tf op to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, warms up the session caches by running an untimed run.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-execution wall time of the op in seconds.
      This is the median time (with respect to `iters`) it takes for the op
      to be executed `iters` times.
    """

    if context.executing_eagerly():
      return self._run_eager_benchmark(iterable=op, iters=iters, warmup=warmup)

    return self._run_graph_benchmark(
        iterable=op, iters=iters, warmup=warmup, session_config=session_config)

  def run_benchmark(self,
                    dataset,
                    num_elements,
                    iters=1,
                    warmup=True,
                    apply_default_optimizations=False):
                    apply_default_optimizations=False,
                    session_config=None):
    """Benchmarks the dataset.

    Runs the dataset `iters` times. In each iteration, the benchmark measures
@@ -50,6 +154,8 @@ class DatasetBenchmarkBase(test.Benchmark):
      warmup: If true, warms up the session caches by running an untimed run.
      apply_default_optimizations: Determines whether default optimizations
        should be applied.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-element wall time of the dataset in seconds.
@@ -70,38 +176,22 @@ class DatasetBenchmarkBase(test.Benchmark):
    # to execute upstream computation. If it is optimized in the future,
    # we will have to change this code.
    dataset = dataset.skip(num_elements - 1)
    deltas = []

    if context.executing_eagerly():
      for _ in range(iters):
        if warmup:
          iterator = iter(dataset)
          next(iterator)

        iterator = iter(dataset)
        start = time.time()
        next(iterator)
        end = time.time()
        deltas.append(end - start)
      return np.median(deltas) / float(num_elements)
      median_duration = self._run_eager_benchmark(
          iterable=dataset, iters=iters, warmup=warmup)
      return median_duration / float(num_elements)

    iterator = dataset_ops.make_initializable_iterator(dataset)
    next_element = iterator.get_next()
    next_element = nest.flatten(next_element)[0]

    for _ in range(iters):
      with session.Session() as sess:
        if warmup:
          # Run once to warm up the session caches.
          sess.run(iterator.initializer)
          sess.run(next_element.op)

        sess.run(iterator.initializer)
        start = time.time()
        sess.run(next_element.op)
        end = time.time()
        deltas.append(end - start)
    return np.median(deltas) / float(num_elements)
    op = nest.flatten(next_element)[0].op
    median_duration = self._run_graph_benchmark(
        iterable=op,
        iters=iters,
        warmup=warmup,
        session_config=session_config,
        initializer=iterator.initializer)
    return median_duration / float(num_elements)

  def run_and_report_benchmark(self,
                               dataset,
@@ -110,7 +200,8 @@ class DatasetBenchmarkBase(test.Benchmark):
                               iters=5,
                               extras=None,
                               warmup=True,
                               apply_default_optimizations=False):
                               apply_default_optimizations=False,
                               session_config=None):
    """Benchmarks the dataset and reports the stats.

    Runs the dataset `iters` times. In each iteration, the benchmark measures
@@ -127,6 +218,8 @@ class DatasetBenchmarkBase(test.Benchmark):
      warmup: If true, warms up the session caches by running an untimed run.
      apply_default_optimizations: Determines whether default optimizations
        should be applied.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-element wall time of the dataset in seconds.
@@ -138,7 +231,8 @@ class DatasetBenchmarkBase(test.Benchmark):
        num_elements=num_elements,
        iters=iters,
        warmup=warmup,
        apply_default_optimizations=apply_default_optimizations)
        apply_default_optimizations=apply_default_optimizations,
        session_config=session_config)
    if context.executing_eagerly():
      name = "{}.eager".format(name)
    else:
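For context on how the new base-class API is meant to be consumed, here is a minimal, hypothetical sketch (not part of this change; the SimpleRangeBenchmark class and its range dataset are illustrative). With these helpers the same subclass code runs under both eager and graph execution, with run_and_report_benchmark dispatching internally to _run_eager_benchmark or _run_graph_benchmark.

# Illustrative only, not part of this commit: a subclass exercising the
# eager/graph-aware helpers added to DatasetBenchmarkBase above.
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class SimpleRangeBenchmark(benchmark_base.DatasetBenchmarkBase):

  def benchmark_range(self):
    dataset = dataset_ops.Dataset.range(10000)
    # Reports the median per-element wall time; the ".eager" suffix is
    # appended to the name automatically when running eagerly.
    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=10000,
        iters=5,
        warmup=True,
        name="range_10000")


if __name__ == "__main__":
  benchmark_base.test.main()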
@@ -17,18 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class AutotuneBenchmark(test.Benchmark):
class AutotuneBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for autotuning performance knobs."""

  def _run_benchmark(self, dataset, autotune, autotune_buffers,
@@ -38,30 +35,17 @@ class AutotuneBenchmark(test.Benchmark):
    options.experimental_optimization.autotune = autotune
    options.experimental_optimization.autotune_buffers = autotune_buffers
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    get_next = iterator.get_next()
    # Run the op directly to avoid copying the tensor to python.
    get_next_op = nest.flatten(get_next)[0].op

    deltas = []
    with session.Session() as sess:
      for _ in range(5):
        sess.run(get_next_op)
      for _ in range(benchmark_iters):
        start = time.time()
        sess.run(get_next_op)
        end = time.time()
        deltas.append(end - start)

    autotune_string = "_autotune_{}".format(
        "parallelism_and_buffer_sizes"
        if autotune_buffers else "parallelism_only")

    self.report_benchmark(
    wall_time = self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=1,
        warmup=True,
        iters=benchmark_iters,
        wall_time=np.median(deltas),
        name=benchmark_label + (autotune_string if autotune else ""))
    return np.median(deltas)
    return wall_time

  def benchmark_batch(self):
    a = self._benchmark_batch(autotune=False)
@@ -101,9 +85,9 @@ class AutotuneBenchmark(test.Benchmark):
    dataset = dataset.map(
        math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
    return self._run_benchmark(
        dataset,
        autotune,
        autotune_buffers,
        dataset=dataset,
        autotune=autotune,
        autotune_buffers=autotune_buffers,
        benchmark_iters=10000,
        benchmark_label="map")

@@ -124,9 +108,9 @@ class AutotuneBenchmark(test.Benchmark):
        math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset.batch(batch_size=batch_size)
    return self._run_benchmark(
        dataset,
        autotune,
        autotune_buffers,
        dataset=dataset,
        autotune=autotune,
        autotune_buffers=autotune_buffers,
        benchmark_iters=1000,
        benchmark_label="map_and_batch")

@@ -148,9 +132,9 @@ class AutotuneBenchmark(test.Benchmark):
        cycle_length=10,
        num_parallel_calls=dataset_ops.AUTOTUNE)
    return self._run_benchmark(
        dataset,
        autotune,
        autotune_buffers,
        dataset=dataset,
        autotune=autotune,
        autotune_buffers=autotune_buffers,
        benchmark_iters=10000,
        benchmark_label="interleave")

@@ -196,9 +180,9 @@ class AutotuneBenchmark(test.Benchmark):
    dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
    dataset = dataset.map(f2, num_parallel_calls=dataset_ops.AUTOTUNE)
    return self._run_benchmark(
        dataset,
        autotune,
        autotune_buffers,
        dataset=dataset,
        autotune=autotune,
        autotune_buffers=autotune_buffers,
        benchmark_iters=10000,
        benchmark_label="map_and_interleave")

@@ -244,12 +228,12 @@ class AutotuneBenchmark(test.Benchmark):
    dataset_c = dataset_c.batch(batch_size=batch_size)
    dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
    return self._run_benchmark(
        dataset,
        autotune,
        autotune_buffers,
        dataset=dataset,
        autotune=autotune,
        autotune_buffers=autotune_buffers,
        benchmark_iters=1000,
        benchmark_label="map_batch_and_interleave")


if __name__ == "__main__":
  test.main()
  benchmark_base.test.main()
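As a reference for the autotune knobs toggled above, a small hedged sketch follows (it reuses the internal module paths already imported in this diff; the range/map pipeline is a stand-in, not one of the benchmarked datasets).

# Illustrative only: toggling autotuning the way _run_benchmark does above.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = dataset_ops.Dataset.range(1000).map(
    math_ops.square, num_parallel_calls=dataset_ops.AUTOTUNE)
options = dataset_ops.Options()
options.experimental_optimization.autotune = True
# When True, prefetch buffer sizes are tuned in addition to parallelism.
options.experimental_optimization.autotune_buffers = False
dataset = dataset.with_options(options)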
@@ -17,18 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


# TODO(b/119837791): Add eager benchmarks too.
class ChooseFastestBenchmark(test.Benchmark):
class ChooseFastestBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for static optimizations."""

  def benchmark_choose_fastest(self):
@@ -42,9 +36,9 @@ class ChooseFastestBenchmark(test.Benchmark):

    merge_dataset = optimization._ChooseFastestDataset(  # pylint: disable=protected-access
        [batch_map_dataset, map_batch_dataset])
    self._benchmark(map_batch_dataset, "map_batch_dataset")
    self._benchmark(batch_map_dataset, "batch_map_dataset")
    self._benchmark(merge_dataset, "merge_dataset")
    self._benchmark(dataset=map_batch_dataset, name="map_batch_dataset")
    self._benchmark(dataset=batch_map_dataset, name="batch_map_dataset")
    self._benchmark(dataset=merge_dataset, name="merge_dataset")

  def benchmark_choose_fastest_first_n_iterations(self):

@@ -58,48 +52,24 @@ class ChooseFastestBenchmark(test.Benchmark):
    merge_dataset = optimization._ChooseFastestDataset(  # pylint: disable=protected-access
        [batch_map_dataset, map_batch_dataset])

    self._benchmark_first_n(map_batch_dataset, "map_batch_dataset")
    self._benchmark_first_n(batch_map_dataset, "batch_map_dataset")
    self._benchmark_first_n(merge_dataset, "merge_dataset")
    self._benchmark_first_n(dataset=map_batch_dataset, name="map_batch_dataset")
    self._benchmark_first_n(dataset=batch_map_dataset, name="batch_map_dataset")
    self._benchmark_first_n(dataset=merge_dataset, name="merge_dataset")

  def _benchmark_first_n(self, dataset, name):
    n = 10  # The default num_experiments for ChooseFastestDataset
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    deltas = []
    for _ in range(100):
      with session.Session() as sess:
        start = time.time()
        for _ in range(n):
          sess.run(next_element.op)
        end = time.time()
        deltas.append(end - start)
    median_wall_time = np.median(deltas) / n
    self.report_benchmark(
        iters=n, wall_time=median_wall_time, name=name + "_first_%d" % n)
    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=n,
        iters=100,
        warmup=True,
        name=name + "_first_%d" % n)

  def _benchmark(self, dataset, name):
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    with session.Session() as sess:
      # Run 10 steps to warm up the session caches before taking the first
      # measurement. Additionally, 10 is the default num_experiments for
      # ChooseFastestDataset.
      for _ in range(10):
        sess.run(next_element.op)
      deltas = []
      for _ in range(50):
        start = time.time()
        for _ in range(50):
          sess.run(next_element.op)
        end = time.time()
        deltas.append(end - start)

    median_wall_time = np.median(deltas) / 100
    self.report_benchmark(iters=100, wall_time=median_wall_time, name=name)
    self.run_and_report_benchmark(
        dataset=dataset, num_elements=100, iters=100, warmup=True, name=name)


if __name__ == "__main__":
  test.main()
  benchmark_base.test.main()
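A hedged sketch of the measurement _benchmark_first_n now delegates to the base class: num_elements bounds how many elements are pulled per iteration, so the reported value is a per-element wall time over just the first n elements. The FirstNBenchmark class and its toy pipeline below are illustrative, not part of this change.

# Illustrative only: timing just the first n elements of a pipeline through
# run_and_report_benchmark, as _benchmark_first_n above does.
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class FirstNBenchmark(benchmark_base.DatasetBenchmarkBase):

  def benchmark_first_10(self):
    n = 10  # Matches the default num_experiments for ChooseFastestDataset.
    dataset = dataset_ops.Dataset.range(1000).map(lambda x: x + 1).batch(8)
    self.run_and_report_benchmark(
        dataset=dataset, num_elements=n, iters=100, warmup=True,
        name="toy_first_%d" % n)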
@@ -19,25 +19,21 @@ from __future__ import print_function

import hashlib
import itertools
import time

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test

_NUMPY_RANDOM_SEED = 42


class MapAndBatchBenchmark(test.Benchmark):
class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.experimental.map_and_batch()`."""

  def benchmark_map_and_batch(self):
@@ -45,53 +41,23 @@ class MapAndBatchBenchmark(test.Benchmark):
    shapes = [(), (10,), (10, 10), (10, 10, 10), (224, 224, 3)]
    batch_size_values = [1, 32, 64, 128, 1024]

    shape_placeholder = array_ops.placeholder(dtypes.int64, shape=[None])
    batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])

    dataset = dataset_ops.Dataset.range(1000000000)

    dense_value = random_ops.random_normal(shape=shape_placeholder)

    dataset = dataset.apply(batching.map_and_batch(
        lambda _: dense_value, batch_size_placeholder))
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    for shape in shapes:
      for batch_size in batch_size_values:

        with session.Session() as sess:
          sess.run(iterator.initializer, feed_dict={
              shape_placeholder: shape, batch_size_placeholder: batch_size})
        dataset = dataset_ops.Dataset.range(1000000000)
        dense_value = random_ops.random_normal(shape=shape)

          # Use a C++ callable to minimize the Python overhead in the benchmark.
          callable_opts = config_pb2.CallableOptions()
          callable_opts.target.append(next_element.op.name)
          op_callable = sess._make_callable_from_options(callable_opts)  # pylint: disable=protected-access
        dataset = dataset.apply(
            batching.map_and_batch(lambda _: dense_value, batch_size))
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)

          # Run five steps to warm up the session caches before taking the
          # first measurement.
          for _ in range(5):
            op_callable()
          deltas = []
          overall_start = time.time()
          # Run at least five repetitions and for at least five seconds.
          while len(deltas) < 5 or time.time() - overall_start < 5.0:
            start = time.time()
            for _ in range(100):
              op_callable()
            end = time.time()
            deltas.append(end - start)
          del op_callable

          median_wall_time = np.median(deltas) / 100.0
          iters = len(deltas) * 100

          self.report_benchmark(
              iters=iters, wall_time=median_wall_time,
        self.run_and_report_benchmark(
            dataset=dataset,
            num_elements=batch_size,
            iters=100,
            warmup=True,
            name="num_elements_%d_batch_size_%d" % (np.prod(shape), batch_size))

  def benchmark_map_and_batch_chaining_versus_fusing(self):
@@ -143,24 +109,15 @@ class MapAndBatchBenchmark(test.Benchmark):
          (element_size * batch_size) // min(num_calls, inter_op))
      # By default the chained map().batch() calls will not be fused.
      chained_dataset = make_dataset(element_size, num_calls, batch_size)
      chained_iterator = dataset_ops.make_one_shot_iterator(chained_dataset)
      chained_get_next = chained_iterator.get_next()
      chained_deltas = []
      with session.Session(
          config=config_pb2.ConfigProto(
              inter_op_parallelism_threads=inter_op,
              use_per_session_threads=True)) as sess:
        for _ in range(5):
          sess.run(chained_get_next.op)
        for _ in range(num_iters):
          start = time.time()
          sess.run(chained_get_next.op)
          end = time.time()
          chained_deltas.append(end - start)
      session_config = config_pb2.ConfigProto(
          inter_op_parallelism_threads=inter_op, use_per_session_threads=True)

      self.report_benchmark(
      self.run_and_report_benchmark(
          dataset=chained_dataset,
          iters=num_iters,
          wall_time=np.median(chained_deltas),
          num_elements=batch_size,
          warmup=True,
          session_config=session_config,
          name=name("chained", label, num_calls, inter_op, element_size,
                    batch_size))

@@ -168,30 +125,16 @@ class MapAndBatchBenchmark(test.Benchmark):
      options = dataset_ops.Options()
      options.experimental_optimization.map_and_batch_fusion = True
      fused_dataset = chained_dataset.with_options(options)
      fused_iterator = dataset_ops.make_one_shot_iterator(fused_dataset)
      fused_get_next = fused_iterator.get_next()
      fused_deltas = []
      with session.Session(
          config=config_pb2.ConfigProto(
              inter_op_parallelism_threads=inter_op,
              use_per_session_threads=True)) as sess:

        for _ in range(5):
          sess.run(fused_get_next.op)
        for _ in range(num_iters):
          start = time.time()
          sess.run(fused_get_next.op)
          end = time.time()
          fused_deltas.append(end - start)

      self.report_benchmark(
      self.run_and_report_benchmark(
          dataset=fused_dataset,
          iters=num_iters,
          wall_time=np.median(fused_deltas),
          num_elements=batch_size,
          warmup=True,
          session_config=session_config,
          name=name("fused", label, num_calls, inter_op, element_size,
                    batch_size))

      print()

    np.random.seed(_NUMPY_RANDOM_SEED)
    benchmark("Sequential element size evaluation", seq_elem_size_series)
    benchmark("Sequential batch size evaluation", seq_batch_size_series)
@@ -202,4 +145,4 @@ class MapAndBatchBenchmark(test.Benchmark):


if __name__ == "__main__":
  test.main()
  benchmark_base.test.main()
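For reference, a minimal sketch of the per-configuration pipeline the rewritten loop builds; it mirrors the calls in this diff rather than defining anything new, and the concrete shape and batch size below are chosen arbitrarily for illustration.

# Illustrative only: one (shape, batch_size) configuration from the loop above.
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import random_ops

shape, batch_size = (10, 10), 32
dataset = dataset_ops.Dataset.range(1000000000)
dense_value = random_ops.random_normal(shape=shape)
dataset = dataset.apply(
    batching.map_and_batch(lambda _: dense_value, batch_size))
options = dataset_ops.Options()
# Keep the measurement focused on map_and_batch itself.
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)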
@@ -17,37 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


# TODO(b/119837791): Add eager benchmarks too.
class MapDefunBenchmark(test.Benchmark):
class MapDefunBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for MapDefunOp."""

  def _run(self, op, name=None, num_iters=3000):
    with session.Session() as sess:
      for _ in range(5):
        sess.run(op)
      start = time.time()
      for _ in range(num_iters):
        sess.run(op)
      end = time.time()
      mean_us = (end - start) * 1e6 / num_iters
      self.report_benchmark(
          name=name,
          iters=num_iters,
          wall_time=mean_us,
          extras={"examples_per_sec": num_iters / (end - start)})

    wall_time = self.run_op_benchmark(op=op, iters=num_iters, warmup=True)
    zero_division_delta = 1e-100
    wall_time = wall_time + zero_division_delta
    self.report_benchmark(
        name=name,
        iters=num_iters,
        wall_time=wall_time,
        extras={"examples_per_sec": 1 / float(wall_time)})

  def benchmark_defun_vs_map_fn(self):
    """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
@@ -59,17 +52,21 @@ class MapDefunBenchmark(test.Benchmark):
    def fn(x):
      return array_ops.identity(x)

    base = math_ops.range(100)
    base = math_ops.range(10000)
    for input_size in [10, 100, 1000, 10000]:
      num_iters = 100000 // input_size
      num_iters = 10000 // input_size
      map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
      map_fn_op = map_fn.map_fn(fn, base)

      self._run(
          map_defun_op, "with_defun_size_%d" % input_size, num_iters=num_iters)
          op=map_defun_op,
          name="with_defun_size_%d" % input_size,
          num_iters=num_iters)
      self._run(
          map_fn_op, "without_defun_size_%d" % input_size, num_iters=num_iters)
          op=map_fn_op,
          name="without_defun_size_%d" % input_size,
          num_iters=num_iters)


if __name__ == "__main__":
  test.main()
  benchmark_base.test.main()
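A hedged sketch of the bookkeeping in the rewritten _run: run_op_benchmark returns a median per-execution wall time in seconds, and a tiny epsilon guards the examples_per_sec division. The IdentityOpBenchmark class and its op below are hypothetical stand-ins, not part of this change.

# Illustrative only: reporting an op benchmark the way _run above does.
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.ops import array_ops


class IdentityOpBenchmark(benchmark_base.DatasetBenchmarkBase):

  def benchmark_identity(self):
    op = array_ops.identity(array_ops.zeros([1000]))
    wall_time = self.run_op_benchmark(op=op, iters=100, warmup=True)
    wall_time += 1e-100  # Guard the division below, as _run does.
    self.report_benchmark(
        name="identity_1000",
        iters=100,
        wall_time=wall_time,
        extras={"examples_per_sec": 1 / float(wall_time)})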
@@ -17,16 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np

from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test

NON_PARALLEL = "non_parallel"
EXPERIMENTAL_PARALLEL = "experimental_parallel"
@@ -65,7 +61,7 @@ def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):
  return fake_dataset_fn


class ParallelInterleaveBenchmark(test.Benchmark):
class ParallelInterleaveBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.experimental.parallel_interleave()`."""

  def apply_interleave(self, interleave_version, dataset, interleave_fn,
@@ -94,8 +90,12 @@ class ParallelInterleaveBenchmark(test.Benchmark):
                   num_parallel_calls=None):
    dataset = dataset_ops.Dataset.range(1).repeat()
    interleave_fn = _make_fake_dataset_fn(initial_delay, remainder_delay)
    return self.apply_interleave(interleave_version, dataset, interleave_fn,
                                 cycle_length, num_parallel_calls)
    return self.apply_interleave(
        interleave_version=interleave_version,
        dataset=dataset,
        interleave_fn=interleave_fn,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)

  def _benchmark(self,
                 interleave_version,
@@ -107,26 +107,29 @@ class ParallelInterleaveBenchmark(test.Benchmark):
                 num_parallel_calls=None,
                 attach_stats_aggregator=False,
                 name=None):
    ds = self.make_dataset(interleave_version, initial_delay_us,
                           remainder_delay_us, cycle_length, num_parallel_calls)
    dataset = self.make_dataset(
        interleave_version=interleave_version,
        initial_delay=initial_delay_us,
        remainder_delay=remainder_delay_us,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)
    if attach_stats_aggregator:
      aggregator = stats_aggregator.StatsAggregator()
      opts = dataset_ops.Options()
      opts.experimental_stats.aggregator = aggregator
      ds = ds.with_options(opts)
      dataset = dataset.with_options(opts)

    ds = ds.skip(num_elements)
    deltas = []
    for _ in range(iters):
      start = time.time()
      next(iter(ds))
      deltas.append(time.time() - start)
    self.report_benchmark(iters=iters, wall_time=np.median(deltas), name=name)
    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=num_elements,
        iters=iters,
        warmup=True,
        name=name)

  def benchmark_remote_file_simulation(self):
    for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
      self._benchmark(
          version,
          interleave_version=version,
          initial_delay_us=100 * 1000,
          remainder_delay_us=1000,
          num_elements=5000,
@@ -135,14 +138,16 @@ class ParallelInterleaveBenchmark(test.Benchmark):
  def benchmark_fast_input(self):
    for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
      self._benchmark(
          version, num_elements=200000, name="fast_input_" + version)
          interleave_version=version,
          num_elements=200000,
          name="fast_input_" + version)

  # Measure the overhead of parallel interleaves compared to non-parallel
  # interleave.
  def benchmark_single_cycle(self):
    for version in [NON_PARALLEL, EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
      self._benchmark(
          version,
          interleave_version=version,
          cycle_length=1,
          num_elements=200000,
          name="single_cycle_" + version)
@@ -151,7 +156,7 @@ class ParallelInterleaveBenchmark(test.Benchmark):
  # cannot be compared here because it sets num_parallel_calls = cycle_length.
  def benchmark_single_parallel_call(self):
    self._benchmark(
        CORE_PARALLEL,
        interleave_version=CORE_PARALLEL,
        num_elements=200000,
        num_parallel_calls=1,
        name="single_parallel_call_" + CORE_PARALLEL)
@@ -159,14 +164,14 @@ class ParallelInterleaveBenchmark(test.Benchmark):
  def benchmark_long_cycle(self):
    for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
      self._benchmark(
          version,
          interleave_version=version,
          cycle_length=1000,
          num_elements=100000,
          name="long_cycle_" + version)

  def benchmark_stats(self):
    self._benchmark(
        CORE_PARALLEL,
        interleave_version=CORE_PARALLEL,
        cycle_length=50,
        num_elements=1000,
        name="stats",
@@ -174,5 +179,4 @@ class ParallelInterleaveBenchmark(test.Benchmark):


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()
  benchmark_base.test.main()
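To illustrate the pattern _benchmark now relies on, here is a hedged, self-contained sketch of a core parallel interleave measured through the base class (the ToyInterleaveBenchmark class and its tiny datasets are illustrative only; the real benchmark builds its inputs via _make_fake_dataset_fn and apply_interleave).

# Illustrative only: a core parallel interleave timed via the base class.
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class ToyInterleaveBenchmark(benchmark_base.DatasetBenchmarkBase):

  def benchmark_toy_interleave(self):
    dataset = dataset_ops.Dataset.range(1).repeat()
    dataset = dataset.interleave(
        lambda _: dataset_ops.Dataset.range(100),
        cycle_length=10,
        num_parallel_calls=dataset_ops.AUTOTUNE)
    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=10000,
        iters=5,
        warmup=True,
        name="toy_interleave")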
@@ -17,43 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import resampling
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


def _time_resampling(data_np, target_dist, init_dist, num_to_sample):  # pylint: disable=missing-docstring
  dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()

  # Reshape distribution via rejection sampling.
  dataset = dataset.apply(
      resampling.rejection_resample(
          class_func=lambda x: x,
          target_dist=target_dist,
          initial_dist=init_dist,
          seed=142))

  options = dataset_ops.Options()
  options.experimental_optimization.apply_default_optimizations = False
  dataset = dataset.with_options(options)
  get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

  with session.Session() as sess:
    start_time = time.time()
    for _ in xrange(num_to_sample):
      sess.run(get_next)
    end_time = time.time()

  return end_time - start_time


class RejectionResampleBenchmark(test.Benchmark):
class RejectionResampleBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.experimental.rejection_resample()`."""

  def benchmark_resample_performance(self):
@@ -63,12 +34,28 @@ class RejectionResampleBenchmark(test.Benchmark):
    # We don't need many samples to test a dirac-delta target distribution
    num_samples = 1000
    data_np = np.random.choice(num_classes, num_samples, p=init_dist)
    # Prepare the dataset
    dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
    # Reshape distribution via rejection sampling.
    dataset = dataset.apply(
        resampling.rejection_resample(
            class_func=lambda x: x,
            target_dist=target_dist,
            initial_dist=init_dist,
            seed=142))
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)

    resample_time = _time_resampling(
        data_np, target_dist, init_dist, num_to_sample=1000)
    wall_time = self.run_benchmark(
        dataset=dataset, num_elements=num_samples, iters=10, warmup=True)
    resample_time = wall_time * num_samples

    self.report_benchmark(iters=1000, wall_time=resample_time, name="resample")
    self.report_benchmark(
        iters=10,
        wall_time=resample_time,
        name="resample_{}".format(num_samples))


if __name__ == "__main__":
  test.main()
  benchmark_base.test.main()
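To make the reporting change explicit: run_benchmark returns a per-element wall time, so the total time to draw the resampled elements is recovered by multiplying back by the sample count. A small hedged sketch with an illustrative value:

# Illustrative only: converting the per-element result of run_benchmark back
# into the total time to draw num_samples resampled elements.
num_samples = 1000
per_element_wall_time = 1.2e-4  # example value returned by run_benchmark
resample_time = per_element_wall_time * num_samples  # ~0.12 s for 1000 samples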