Merge pull request #47015 from kvignesh1420:eager-exp-bechmarks-p2

PiperOrigin-RevId: 356831159
Change-Id: I4503a76dec610fb0e5817110fa54f1645cf84ecb
TensorFlower Gardener 2021-02-10 14:53:08 -08:00
commit cc2635d67c
7 changed files with 263 additions and 284 deletions


@ -31,12 +31,116 @@ from tensorflow.python.platform import test
class DatasetBenchmarkBase(test.Benchmark):
"""Base class for dataset benchmarks."""
def _run_eager_benchmark(self, iterable, iters, warmup):
"""Benchmark the iterable in eager mode.
Runs the iterable `iters` times. In each iteration, the benchmark measures
the time it takes to go execute the iterable.
Args:
iterable: The tf op or tf.data Dataset to benchmark.
iters: Number of times to repeat the timing.
warmup: If true, runs the iterable once before the timed runs to warm up any caches.
Returns:
A float, the median time in seconds (across the `iters` repetitions) it
takes to execute the iterable once.
"""
deltas = []
if not context.executing_eagerly():
raise RuntimeError(
"Eager mode benchmarking is not supported in graph mode.")
for _ in range(iters):
if warmup:
iterator = iter(iterable)
next(iterator)
iterator = iter(iterable)
start = time.time()
next(iterator)
end = time.time()
deltas.append(end - start)
return np.median(deltas)
def _run_graph_benchmark(self,
iterable,
iters,
warmup,
session_config,
initializer=None):
"""Benchmarks the iterable in graph mode.
Runs the iterable `iters` times. In each iteration, the benchmark measures
the time it takes to execute the iterable.
Args:
iterable: The tf op or tf.data Dataset to benchmark.
iters: Number of times to repeat the timing.
warmup: If true, warms up the session caches by running an untimed run.
session_config: A ConfigProto protocol buffer with configuration options
for the session. Applicable only for benchmarking in graph mode.
initializer: The initializer op required to initialize the iterable.
Returns:
A float, the median time in seconds (across the `iters` repetitions) it
takes to execute the iterable once.
"""
deltas = []
if context.executing_eagerly():
raise RuntimeError(
"Graph mode benchmarking is not supported in eager mode.")
for _ in range(iters):
with session.Session(config=session_config) as sess:
if warmup:
# Run once to warm up the session caches.
if initializer:
sess.run(initializer)
sess.run(iterable)
if initializer:
sess.run(initializer)
start = time.time()
sess.run(iterable)
end = time.time()
deltas.append(end - start)
return np.median(deltas)
def run_op_benchmark(self, op, iters=1, warmup=True, session_config=None):
"""Benchmarks the op.
Runs the op `iters` times. In each iteration, the benchmark measures
the time it takes to execute the op.
Args:
op: The tf op to benchmark.
iters: Number of times to repeat the timing.
warmup: If true, warms up the session caches by running an untimed run.
session_config: A ConfigProto protocol buffer with configuration options
for the session. Applicable only for benchmarking in graph mode.
Returns:
A float, representing the per-execution wall time of the op in seconds.
This is the median across the `iters` timed runs.
"""
if context.executing_eagerly():
return self._run_eager_benchmark(iterable=op, iters=iters, warmup=warmup)
return self._run_graph_benchmark(
iterable=op, iters=iters, warmup=warmup, session_config=session_config)
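
For orientation, here is a minimal usage sketch of run_op_benchmark; the op, iteration count, and ConfigProto settings are illustrative assumptions, not part of this change.

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.ops import math_ops


class SquareOpBenchmark(benchmark_base.DatasetBenchmarkBase):  # hypothetical

  def benchmark_square(self):
    op = math_ops.square(math_ops.range(1000))
    # In graph mode the ConfigProto is forwarded to the Session; under eager
    # execution the helper simply times iteration over the op's output.
    wall_time = self.run_op_benchmark(
        op=op,
        iters=100,
        warmup=True,
        session_config=config_pb2.ConfigProto(inter_op_parallelism_threads=4))
    self.report_benchmark(iters=100, wall_time=wall_time, name="square")
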
def run_benchmark(self,
dataset,
num_elements,
iters=1,
warmup=True,
apply_default_optimizations=False):
apply_default_optimizations=False,
session_config=None):
"""Benchmarks the dataset.
Runs the dataset `iters` times. In each iteration, the benchmark measures
@ -50,6 +154,8 @@ class DatasetBenchmarkBase(test.Benchmark):
warmup: If true, warms up the session caches by running an untimed run.
apply_default_optimizations: Determines whether default optimizations
should be applied.
session_config: A ConfigProto protocol buffer with configuration options
for the session. Applicable only for benchmarking in graph mode.
Returns:
A float, representing the per-element wall time of the dataset in seconds.
@ -70,38 +176,22 @@ class DatasetBenchmarkBase(test.Benchmark):
# to execute upstream computation. If it is optimized in the future,
# we will have to change this code.
dataset = dataset.skip(num_elements - 1)
deltas = []
if context.executing_eagerly():
for _ in range(iters):
if warmup:
iterator = iter(dataset)
next(iterator)
iterator = iter(dataset)
start = time.time()
next(iterator)
end = time.time()
deltas.append(end - start)
return np.median(deltas) / float(num_elements)
median_duration = self._run_eager_benchmark(
iterable=dataset, iters=iters, warmup=warmup)
return median_duration / float(num_elements)
iterator = dataset_ops.make_initializable_iterator(dataset)
next_element = iterator.get_next()
next_element = nest.flatten(next_element)[0]
for _ in range(iters):
with session.Session() as sess:
if warmup:
# Run once to warm up the session caches.
sess.run(iterator.initializer)
sess.run(next_element.op)
sess.run(iterator.initializer)
start = time.time()
sess.run(next_element.op)
end = time.time()
deltas.append(end - start)
return np.median(deltas) / float(num_elements)
op = nest.flatten(next_element)[0].op
median_duration = self._run_graph_benchmark(
iterable=op,
iters=iters,
warmup=warmup,
session_config=session_config,
initializer=iterator.initializer)
return median_duration / float(num_elements)
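
To make the per-element accounting concrete: the dataset is advanced past the first `num_elements - 1` elements, a single fetch is timed (which forces the upstream pipeline to produce all `num_elements` elements), and the median over `iters` runs is divided by `num_elements`. A minimal usage sketch with made-up numbers:

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class ToyMapBenchmark(benchmark_base.DatasetBenchmarkBase):  # hypothetical

  def benchmark_map(self):
    dataset = dataset_ops.Dataset.range(10000).map(lambda x: x * 2)
    # Returns seconds per element: if the median timed fetch takes 0.05 s with
    # num_elements=1000, the result is 0.05 / 1000 = 5e-5 s per element.
    per_element_time = self.run_benchmark(
        dataset=dataset, num_elements=1000, iters=10, warmup=True)
    self.report_benchmark(iters=10, wall_time=per_element_time, name="toy_map")
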
def run_and_report_benchmark(self,
dataset,
@ -110,7 +200,8 @@ class DatasetBenchmarkBase(test.Benchmark):
iters=5,
extras=None,
warmup=True,
apply_default_optimizations=False):
apply_default_optimizations=False,
session_config=None):
"""Benchmarks the dataset and reports the stats.
Runs the dataset `iters` times. In each iteration, the benchmark measures
@ -127,6 +218,8 @@ class DatasetBenchmarkBase(test.Benchmark):
warmup: If true, warms up the session caches by running an untimed run.
apply_default_optimizations: Determines whether default optimizations
should be applied.
session_config: A ConfigProto protocol buffer with configuration options
for the session. Applicable only for benchmarking in graph mode.
Returns:
A float, representing the per-element wall time of the dataset in seconds.
@ -138,7 +231,8 @@ class DatasetBenchmarkBase(test.Benchmark):
num_elements=num_elements,
iters=iters,
warmup=warmup,
apply_default_optimizations=apply_default_optimizations)
apply_default_optimizations=apply_default_optimizations,
session_config=session_config)
if context.executing_eagerly():
name = "{}.eager".format(name)
else:
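
The graph-mode branch of the naming logic is cut off by the diff view; presumably it appends an analogous suffix, so a single subclass yields separately named entries under eager and graph execution. A hypothetical end-to-end subclass, mirroring the benchmark files below:

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops


class RangeBenchmark(benchmark_base.DatasetBenchmarkBase):  # hypothetical

  def benchmark_range(self):
    dataset = dataset_ops.Dataset.range(1000000)
    # Reported as "range.eager" under eager execution; the graph-mode suffix
    # is assumed from the branch above.
    self.run_and_report_benchmark(
        dataset=dataset, num_elements=100000, iters=5, name="range")


if __name__ == "__main__":
  benchmark_base.test.main()
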


@ -17,18 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class AutotuneBenchmark(test.Benchmark):
class AutotuneBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for autotuning performance knobs."""
def _run_benchmark(self, dataset, autotune, autotune_buffers,
@ -38,30 +35,17 @@ class AutotuneBenchmark(test.Benchmark):
options.experimental_optimization.autotune = autotune
options.experimental_optimization.autotune_buffers = autotune_buffers
dataset = dataset.with_options(options)
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
# Run the op directly to avoid copying the tensor to python.
get_next_op = nest.flatten(get_next)[0].op
deltas = []
with session.Session() as sess:
for _ in range(5):
sess.run(get_next_op)
for _ in range(benchmark_iters):
start = time.time()
sess.run(get_next_op)
end = time.time()
deltas.append(end - start)
autotune_string = "_autotune_{}".format(
"parallelism_and_buffer_sizes"
if autotune_buffers else "parallelism_only")
self.report_benchmark(
wall_time = self.run_and_report_benchmark(
dataset=dataset,
num_elements=1,
warmup=True,
iters=benchmark_iters,
wall_time=np.median(deltas),
name=benchmark_label + (autotune_string if autotune else ""))
return np.median(deltas)
return wall_time
def benchmark_batch(self):
a = self._benchmark_batch(autotune=False)
@ -101,9 +85,9 @@ class AutotuneBenchmark(test.Benchmark):
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
dataset=dataset,
autotune=autotune,
autotune_buffers=autotune_buffers,
benchmark_iters=10000,
benchmark_label="map")
@ -124,9 +108,9 @@ class AutotuneBenchmark(test.Benchmark):
math_ops.matmul, num_parallel_calls=dataset_ops.AUTOTUNE)
dataset = dataset.batch(batch_size=batch_size)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
dataset=dataset,
autotune=autotune,
autotune_buffers=autotune_buffers,
benchmark_iters=1000,
benchmark_label="map_and_batch")
@ -148,9 +132,9 @@ class AutotuneBenchmark(test.Benchmark):
cycle_length=10,
num_parallel_calls=dataset_ops.AUTOTUNE)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
dataset=dataset,
autotune=autotune,
autotune_buffers=autotune_buffers,
benchmark_iters=10000,
benchmark_label="interleave")
@ -196,9 +180,9 @@ class AutotuneBenchmark(test.Benchmark):
dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
dataset = dataset.map(f2, num_parallel_calls=dataset_ops.AUTOTUNE)
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
dataset=dataset,
autotune=autotune,
autotune_buffers=autotune_buffers,
benchmark_iters=10000,
benchmark_label="map_and_interleave")
@ -244,12 +228,12 @@ class AutotuneBenchmark(test.Benchmark):
dataset_c = dataset_c.batch(batch_size=batch_size)
dataset = dataset_ops.Dataset.zip((dataset, dataset_c))
return self._run_benchmark(
dataset,
autotune,
autotune_buffers,
dataset=dataset,
autotune=autotune,
autotune_buffers=autotune_buffers,
benchmark_iters=1000,
benchmark_label="map_batch_and_interleave")
if __name__ == "__main__":
test.main()
benchmark_base.test.main()


@ -17,18 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
# TODO(b/119837791): Add eager benchmarks too.
class ChooseFastestBenchmark(test.Benchmark):
class ChooseFastestBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for static optimizations."""
def benchmark_choose_fastest(self):
@ -42,9 +36,9 @@ class ChooseFastestBenchmark(test.Benchmark):
merge_dataset = optimization._ChooseFastestDataset( # pylint: disable=protected-access
[batch_map_dataset, map_batch_dataset])
self._benchmark(map_batch_dataset, "map_batch_dataset")
self._benchmark(batch_map_dataset, "batch_map_dataset")
self._benchmark(merge_dataset, "merge_dataset")
self._benchmark(dataset=map_batch_dataset, name="map_batch_dataset")
self._benchmark(dataset=batch_map_dataset, name="batch_map_dataset")
self._benchmark(dataset=merge_dataset, name="merge_dataset")
def benchmark_choose_fastest_first_n_iterations(self):
@ -58,48 +52,24 @@ class ChooseFastestBenchmark(test.Benchmark):
merge_dataset = optimization._ChooseFastestDataset( # pylint: disable=protected-access
[batch_map_dataset, map_batch_dataset])
self._benchmark_first_n(map_batch_dataset, "map_batch_dataset")
self._benchmark_first_n(batch_map_dataset, "batch_map_dataset")
self._benchmark_first_n(merge_dataset, "merge_dataset")
self._benchmark_first_n(dataset=map_batch_dataset, name="map_batch_dataset")
self._benchmark_first_n(dataset=batch_map_dataset, name="batch_map_dataset")
self._benchmark_first_n(dataset=merge_dataset, name="merge_dataset")
def _benchmark_first_n(self, dataset, name):
n = 10 # The default num_experiments for ChooseFastestDataset
iterator = dataset_ops.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
deltas = []
for _ in range(100):
with session.Session() as sess:
start = time.time()
for _ in range(n):
sess.run(next_element.op)
end = time.time()
deltas.append(end - start)
median_wall_time = np.median(deltas) / n
self.report_benchmark(
iters=n, wall_time=median_wall_time, name=name + "_first_%d" % n)
self.run_and_report_benchmark(
dataset=dataset,
num_elements=n,
iters=100,
warmup=True,
name=name + "_first_%d" % n)
def _benchmark(self, dataset, name):
iterator = dataset_ops.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
with session.Session() as sess:
# Run 10 steps to warm up the session caches before taking the first
# measurement. Additionally, 10 is the default num_experiments for
# ChooseFastestDataset.
for _ in range(10):
sess.run(next_element.op)
deltas = []
for _ in range(50):
start = time.time()
for _ in range(50):
sess.run(next_element.op)
end = time.time()
deltas.append(end - start)
median_wall_time = np.median(deltas) / 100
self.report_benchmark(iters=100, wall_time=median_wall_time, name=name)
self.run_and_report_benchmark(
dataset=dataset, num_elements=100, iters=100, warmup=True, name=name)
if __name__ == "__main__":
test.main()
benchmark_base.test.main()


@ -19,25 +19,21 @@ from __future__ import print_function
import hashlib
import itertools
import time
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
_NUMPY_RANDOM_SEED = 42
class MapAndBatchBenchmark(test.Benchmark):
class MapAndBatchBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.experimental.map_and_batch()`."""
def benchmark_map_and_batch(self):
@ -45,53 +41,23 @@ class MapAndBatchBenchmark(test.Benchmark):
shapes = [(), (10,), (10, 10), (10, 10, 10), (224, 224, 3)]
batch_size_values = [1, 32, 64, 128, 1024]
shape_placeholder = array_ops.placeholder(dtypes.int64, shape=[None])
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset_ops.Dataset.range(1000000000)
dense_value = random_ops.random_normal(shape=shape_placeholder)
dataset = dataset.apply(batching.map_and_batch(
lambda _: dense_value, batch_size_placeholder))
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
iterator = dataset_ops.make_initializable_iterator(dataset)
next_element = iterator.get_next()
for shape in shapes:
for batch_size in batch_size_values:
with session.Session() as sess:
sess.run(iterator.initializer, feed_dict={
shape_placeholder: shape, batch_size_placeholder: batch_size})
dataset = dataset_ops.Dataset.range(1000000000)
dense_value = random_ops.random_normal(shape=shape)
# Use a C++ callable to minimize the Python overhead in the benchmark.
callable_opts = config_pb2.CallableOptions()
callable_opts.target.append(next_element.op.name)
op_callable = sess._make_callable_from_options(callable_opts) # pylint: disable=protected-access
dataset = dataset.apply(
batching.map_and_batch(lambda _: dense_value, batch_size))
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
# Run five steps to warm up the session caches before taking the
# first measurement.
for _ in range(5):
op_callable()
deltas = []
overall_start = time.time()
# Run at least five repetitions and for at least five seconds.
while len(deltas) < 5 or time.time() - overall_start < 5.0:
start = time.time()
for _ in range(100):
op_callable()
end = time.time()
deltas.append(end - start)
del op_callable
median_wall_time = np.median(deltas) / 100.0
iters = len(deltas) * 100
self.report_benchmark(
iters=iters, wall_time=median_wall_time,
self.run_and_report_benchmark(
dataset=dataset,
num_elements=batch_size,
iters=100,
warmup=True,
name="num_elements_%d_batch_size_%d" % (np.prod(shape), batch_size))
def benchmark_map_and_batch_chaining_versus_fusing(self):
@ -143,24 +109,15 @@ class MapAndBatchBenchmark(test.Benchmark):
(element_size * batch_size) // min(num_calls, inter_op))
# By default the chained map().batch() calls will not be fused.
chained_dataset = make_dataset(element_size, num_calls, batch_size)
chained_iterator = dataset_ops.make_one_shot_iterator(chained_dataset)
chained_get_next = chained_iterator.get_next()
chained_deltas = []
with session.Session(
config=config_pb2.ConfigProto(
inter_op_parallelism_threads=inter_op,
use_per_session_threads=True)) as sess:
for _ in range(5):
sess.run(chained_get_next.op)
for _ in range(num_iters):
start = time.time()
sess.run(chained_get_next.op)
end = time.time()
chained_deltas.append(end - start)
session_config = config_pb2.ConfigProto(
inter_op_parallelism_threads=inter_op, use_per_session_threads=True)
self.report_benchmark(
self.run_and_report_benchmark(
dataset=chained_dataset,
iters=num_iters,
wall_time=np.median(chained_deltas),
num_elements=batch_size,
warmup=True,
session_config=session_config,
name=name("chained", label, num_calls, inter_op, element_size,
batch_size))
@ -168,30 +125,16 @@ class MapAndBatchBenchmark(test.Benchmark):
options = dataset_ops.Options()
options.experimental_optimization.map_and_batch_fusion = True
fused_dataset = chained_dataset.with_options(options)
fused_iterator = dataset_ops.make_one_shot_iterator(fused_dataset)
fused_get_next = fused_iterator.get_next()
fused_deltas = []
with session.Session(
config=config_pb2.ConfigProto(
inter_op_parallelism_threads=inter_op,
use_per_session_threads=True)) as sess:
for _ in range(5):
sess.run(fused_get_next.op)
for _ in range(num_iters):
start = time.time()
sess.run(fused_get_next.op)
end = time.time()
fused_deltas.append(end - start)
self.report_benchmark(
self.run_and_report_benchmark(
dataset=fused_dataset,
iters=num_iters,
wall_time=np.median(fused_deltas),
num_elements=batch_size,
warmup=True,
session_config=session_config,
name=name("fused", label, num_calls, inter_op, element_size,
batch_size))
print()
np.random.seed(_NUMPY_RANDOM_SEED)
benchmark("Sequential element size evaluation", seq_elem_size_series)
benchmark("Sequential batch size evaluation", seq_batch_size_series)
@ -202,4 +145,4 @@ class MapAndBatchBenchmark(test.Benchmark):
if __name__ == "__main__":
test.main()
benchmark_base.test.main()


@ -17,37 +17,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
# TODO(b/119837791): Add eager benchmarks too.
class MapDefunBenchmark(test.Benchmark):
class MapDefunBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for MapDefunOp."""
def _run(self, op, name=None, num_iters=3000):
with session.Session() as sess:
for _ in range(5):
sess.run(op)
start = time.time()
for _ in range(num_iters):
sess.run(op)
end = time.time()
mean_us = (end - start) * 1e6 / num_iters
self.report_benchmark(
name=name,
iters=num_iters,
wall_time=mean_us,
extras={"examples_per_sec": num_iters / (end - start)})
wall_time = self.run_op_benchmark(op=op, iters=num_iters, warmup=True)
zero_division_delta = 1e-100
wall_time = wall_time + zero_division_delta
self.report_benchmark(
name=name,
iters=num_iters,
wall_time=wall_time,
extras={"examples_per_sec": 1 / float(wall_time)})
def benchmark_defun_vs_map_fn(self):
"""Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
@ -59,17 +52,21 @@ class MapDefunBenchmark(test.Benchmark):
def fn(x):
return array_ops.identity(x)
base = math_ops.range(100)
base = math_ops.range(10000)
for input_size in [10, 100, 1000, 10000]:
num_iters = 100000 // input_size
num_iters = 10000 // input_size
map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
map_fn_op = map_fn.map_fn(fn, base)
self._run(
map_defun_op, "with_defun_size_%d" % input_size, num_iters=num_iters)
op=map_defun_op,
name="with_defun_size_%d" % input_size,
num_iters=num_iters)
self._run(
map_fn_op, "without_defun_size_%d" % input_size, num_iters=num_iters)
op=map_fn_op,
name="without_defun_size_%d" % input_size,
num_iters=num_iters)
if __name__ == "__main__":
test.main()
benchmark_base.test.main()


@ -17,16 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
NON_PARALLEL = "non_parallel"
EXPERIMENTAL_PARALLEL = "experimental_parallel"
@ -65,7 +61,7 @@ def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):
return fake_dataset_fn
class ParallelInterleaveBenchmark(test.Benchmark):
class ParallelInterleaveBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.experimental.parallel_interleave()`."""
def apply_interleave(self, interleave_version, dataset, interleave_fn,
@ -94,8 +90,12 @@ class ParallelInterleaveBenchmark(test.Benchmark):
num_parallel_calls=None):
dataset = dataset_ops.Dataset.range(1).repeat()
interleave_fn = _make_fake_dataset_fn(initial_delay, remainder_delay)
return self.apply_interleave(interleave_version, dataset, interleave_fn,
cycle_length, num_parallel_calls)
return self.apply_interleave(
interleave_version=interleave_version,
dataset=dataset,
interleave_fn=interleave_fn,
cycle_length=cycle_length,
num_parallel_calls=num_parallel_calls)
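
For context, _make_fake_dataset_fn (used by make_dataset above) emulates a source whose first element is slow and whose remaining elements are fast, such as a remote file read. Its body is not shown in this diff; a plausible sketch, assuming the testing.sleep transformation imported in this file, could look like:

from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.ops import dataset_ops


def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):  # sketch only
  """Returns a dataset factory whose first element is artificially slow."""

  def fake_dataset_fn(_):
    # The first element pays the large "open the remote file" latency; the
    # remaining elements pay only the small per-record latency.
    first = dataset_ops.Dataset.range(1).apply(testing.sleep(initial_delay_us))
    rest = dataset_ops.Dataset.range(1, 100).apply(
        testing.sleep(remainder_delay_us))
    return first.concatenate(rest)

  return fake_dataset_fn
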
def _benchmark(self,
interleave_version,
@ -107,26 +107,29 @@ class ParallelInterleaveBenchmark(test.Benchmark):
num_parallel_calls=None,
attach_stats_aggregator=False,
name=None):
ds = self.make_dataset(interleave_version, initial_delay_us,
remainder_delay_us, cycle_length, num_parallel_calls)
dataset = self.make_dataset(
interleave_version=interleave_version,
initial_delay=initial_delay_us,
remainder_delay=remainder_delay_us,
cycle_length=cycle_length,
num_parallel_calls=num_parallel_calls)
if attach_stats_aggregator:
aggregator = stats_aggregator.StatsAggregator()
opts = dataset_ops.Options()
opts.experimental_stats.aggregator = aggregator
ds = ds.with_options(opts)
dataset = dataset.with_options(opts)
ds = ds.skip(num_elements)
deltas = []
for _ in range(iters):
start = time.time()
next(iter(ds))
deltas.append(time.time() - start)
self.report_benchmark(iters=iters, wall_time=np.median(deltas), name=name)
self.run_and_report_benchmark(
dataset=dataset,
num_elements=num_elements,
iters=iters,
warmup=True,
name=name)
def benchmark_remote_file_simulation(self):
for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
self._benchmark(
version,
interleave_version=version,
initial_delay_us=100 * 1000,
remainder_delay_us=1000,
num_elements=5000,
@ -135,14 +138,16 @@ class ParallelInterleaveBenchmark(test.Benchmark):
def benchmark_fast_input(self):
for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
self._benchmark(
version, num_elements=200000, name="fast_input_" + version)
interleave_version=version,
num_elements=200000,
name="fast_input_" + version)
# Measure the overhead of parallel interleaves compared to non-parallel
# interleave.
def benchmark_single_cycle(self):
for version in [NON_PARALLEL, EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
self._benchmark(
version,
interleave_version=version,
cycle_length=1,
num_elements=200000,
name="single_cycle_" + version)
@ -151,7 +156,7 @@ class ParallelInterleaveBenchmark(test.Benchmark):
# cannot be compared here because it sets num_parallel_calls = cycle_length.
def benchmark_single_parallel_call(self):
self._benchmark(
CORE_PARALLEL,
interleave_version=CORE_PARALLEL,
num_elements=200000,
num_parallel_calls=1,
name="single_parallel_call_" + CORE_PARALLEL)
@ -159,14 +164,14 @@ class ParallelInterleaveBenchmark(test.Benchmark):
def benchmark_long_cycle(self):
for version in [EXPERIMENTAL_PARALLEL, CORE_PARALLEL]:
self._benchmark(
version,
interleave_version=version,
cycle_length=1000,
num_elements=100000,
name="long_cycle_" + version)
def benchmark_stats(self):
self._benchmark(
CORE_PARALLEL,
interleave_version=CORE_PARALLEL,
cycle_length=50,
num_elements=1000,
name="stats",
@ -174,5 +179,4 @@ class ParallelInterleaveBenchmark(test.Benchmark):
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()
benchmark_base.test.main()


@ -17,43 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import resampling
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
def _time_resampling(data_np, target_dist, init_dist, num_to_sample): # pylint: disable=missing-docstring
dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
# Reshape distribution via rejection sampling.
dataset = dataset.apply(
resampling.rejection_resample(
class_func=lambda x: x,
target_dist=target_dist,
initial_dist=init_dist,
seed=142))
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()
with session.Session() as sess:
start_time = time.time()
for _ in xrange(num_to_sample):
sess.run(get_next)
end_time = time.time()
return end_time - start_time
class RejectionResampleBenchmark(test.Benchmark):
class RejectionResampleBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.experimental.rejection_resample()`."""
def benchmark_resample_performance(self):
@ -63,12 +34,28 @@ class RejectionResampleBenchmark(test.Benchmark):
# We don't need many samples to test a dirac-delta target distribution
num_samples = 1000
data_np = np.random.choice(num_classes, num_samples, p=init_dist)
# Prepare the dataset
dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
# Reshape distribution via rejection sampling.
dataset = dataset.apply(
resampling.rejection_resample(
class_func=lambda x: x,
target_dist=target_dist,
initial_dist=init_dist,
seed=142))
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
resample_time = _time_resampling(
data_np, target_dist, init_dist, num_to_sample=1000)
wall_time = self.run_benchmark(
dataset=dataset, num_elements=num_samples, iters=10, warmup=True)
resample_time = wall_time * num_samples
self.report_benchmark(iters=1000, wall_time=resample_time, name="resample")
self.report_benchmark(
iters=10,
wall_time=resample_time,
name="resample_{}".format(num_samples))
if __name__ == "__main__":
test.main()
benchmark_base.test.main()