refactor eager and graph benchmark methods
This commit is contained in:
parent 4ef2344b3a
commit bbbdd5327d
@@ -31,10 +31,97 @@ from tensorflow.python.platform import test


 class DatasetBenchmarkBase(test.Benchmark):
   """Base class for dataset benchmarks."""

+  def _run_eager_benchmark(self, op_or_dataset, iters, warmup):
+    """Benchmarks the op in eager mode.
+
+    Runs the op `iters` times. In each iteration, the benchmark measures
+    the time it takes to execute the op.
+
+    Args:
+      op_or_dataset: The tf op or tf.data Dataset to benchmark.
+      iters: Number of times to repeat the timing.
+      warmup: If true, warms up the session caches by running an untimed run.
+
+    Returns:
+      A float, the median across `iters` runs of the time it takes to
+      execute the op once.
+    """
+    deltas = []
+    if not context.executing_eagerly():
+      raise RuntimeError(
+          "Eager mode benchmarking is not supported in graph mode.")
+
+    for _ in range(iters):
+      if warmup:
+        iterator = iter(op_or_dataset)
+        next(iterator)
+
+      iterator = iter(op_or_dataset)
+      start = time.time()
+      next(iterator)
+      end = time.time()
+      deltas.append(end - start)
+    return np.median(deltas)
+
+  def _run_graph_benchmark(self,
+                           op_or_dataset,
+                           iters,
+                           warmup,
+                           session_config,
+                           is_dataset=True):
+    """Benchmarks the op in graph mode.
+
+    Runs the op `iters` times. In each iteration, the benchmark measures
+    the time it takes to execute the op.
+
+    Args:
+      op_or_dataset: The tf op or tf.data Dataset to benchmark.
+      iters: Number of times to repeat the timing.
+      warmup: If true, warms up the session caches by running an untimed run.
+      session_config: A ConfigProto protocol buffer with configuration options
+        for the session. Applicable only for benchmarking in graph mode.
+      is_dataset: A boolean value representing whether the op is a tf.data
+        Dataset or not.
+
+    Returns:
+      A float, the median across `iters` runs of the time it takes to
+      execute the op once.
+    """
+    deltas = []
+    if context.executing_eagerly():
+      raise RuntimeError(
+          "Graph mode benchmarking is not supported in eager mode.")
+
+    if is_dataset:
+      iterator = dataset_ops.make_initializable_iterator(op_or_dataset)
+      next_element = iterator.get_next()
+      op = nest.flatten(next_element)[0].op
+    else:
+      op = op_or_dataset
+
+    for _ in range(iters):
+      with session.Session(config=session_config) as sess:
+        if warmup:
+          if is_dataset:
+            sess.run(iterator.initializer)
+          # Run once to warm up the session caches.
+          sess.run(op)
+
+        if is_dataset:
+          sess.run(iterator.initializer)
+        start = time.time()
+        sess.run(op)
+        end = time.time()
+        deltas.append(end - start)
+    return np.median(deltas)
+
   def run_op_benchmark(self,
                        op,
                        iters=1,
-                       warmup=True):
+                       warmup=True,
+                       session_config=None):
     """Benchmarks the op.

     Runs the op `iters` times. In each iteration, the benchmark measures
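The eager helper times a single `next()` call per iteration, after an optional untimed warmup fetch, and reports the median so that outlier iterations (one-off tracing or allocation costs, say) do not skew the result. A minimal standalone sketch of the same timing pattern, runnable on any Python iterable; the helper name is illustrative, not part of the commit:

    import time
    import numpy as np

    def median_first_step_time(iterable, iters=10, warmup=True):
      # Mirrors _run_eager_benchmark: an untimed warmup fetch, then a
      # fresh iterator whose first element is timed; median over iters.
      deltas = []
      for _ in range(iters):
        if warmup:
          next(iter(iterable))  # untimed run to warm caches
        iterator = iter(iterable)
        start = time.time()
        next(iterator)          # timed: a single element
        deltas.append(time.time() - start)
      return np.median(deltas)

    print(median_first_step_time(range(10)))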
@@ -44,6 +131,8 @@ class DatasetBenchmarkBase(test.Benchmark):
       op: The tf op to benchmark.
       iters: Number of times to repeat the timing.
       warmup: If true, warms up the session caches by running an untimed run.
+      session_config: A ConfigProto protocol buffer with configuration options
+        for the session. Applicable only for benchmarking in graph mode.

     Returns:
       A float, representing the per-execution wall time of the op in seconds.
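The new `session_config` argument is passed straight through to `session.Session(config=...)` in the graph helper. A hedged example of building one; the specific thread settings are illustrative, chosen because single-threaded execution tends to give more stable timings:

    from tensorflow.core.protobuf import config_pb2

    session_config = config_pb2.ConfigProto(
        intra_op_parallelism_threads=1,  # serialize work inside each op
        inter_op_parallelism_threads=1)  # serialize work across ops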
@@ -52,27 +141,19 @@ class DatasetBenchmarkBase(test.Benchmark):
     """

     if context.executing_eagerly():
-      for _ in range(iters):
-        if warmup:
-          iterator = iter(op)
-          next(iterator)
-
-        iterator = iter(op)
-        start = time.time()
-        next(iterator)
-        end = time.time()
-      return (end - start) / iters
-
-    with session.Session() as sess:
-      if warmup:
-        # Run once to warm up the session caches.
-        sess.run(op)
-      start = time.time()
-      for _ in range(iters):
-        sess.run(op)
-      end = time.time()
-
-    return (end - start) / iters
+      return self._run_eager_benchmark(
+          op_or_dataset=op,
+          iters=iters,
+          warmup=warmup
+      )
+
+    return self._run_graph_benchmark(
+        op_or_dataset=op,
+        iters=iters,
+        warmup=warmup,
+        session_config=session_config,
+        is_dataset=False
+    )

   def run_benchmark(self,
                     dataset,
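With the refactor, `run_op_benchmark` is reduced to a dispatcher that picks the helper matching the current execution mode. A hypothetical call site in a benchmark subclass; the subclass and the benchmarked pipeline are illustrative:

    class MyOpBenchmark(DatasetBenchmarkBase):  # hypothetical subclass

      def benchmark_scalar_multiply(self):
        # In eager mode a tf.data Dataset works as the "op"; in graph
        # mode this would be a tf op built into the default graph.
        dataset = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
        wall_time = self.run_op_benchmark(
            op=dataset,
            iters=100,            # median over 100 timed executions
            warmup=True,          # untimed run before each timed one
            session_config=None)  # consulted only in graph mode
        self.report_benchmark(iters=100, wall_time=wall_time)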
@@ -116,38 +197,22 @@ class DatasetBenchmarkBase(test.Benchmark):
     # to execute upstream computation. If it is optimized in the future,
     # we will have to change this code.
     dataset = dataset.skip(num_elements - 1)
-    deltas = []

     if context.executing_eagerly():
-      for _ in range(iters):
-        if warmup:
-          iterator = iter(dataset)
-          next(iterator)
-
-        iterator = iter(dataset)
-        start = time.time()
-        next(iterator)
-        end = time.time()
-        deltas.append(end - start)
-      return np.median(deltas) / float(num_elements)
-
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-    next_element = nest.flatten(next_element)[0]
-
-    for _ in range(iters):
-      with session.Session(config=session_config) as sess:
-        if warmup:
-          # Run once to warm up the session caches.
-          sess.run(iterator.initializer)
-          sess.run(next_element.op)
-
-        sess.run(iterator.initializer)
-        start = time.time()
-        sess.run(next_element.op)
-        end = time.time()
-        deltas.append(end - start)
-    return np.median(deltas) / float(num_elements)
+      median_duration = self._run_eager_benchmark(
+          op_or_dataset=dataset,
+          iters=iters,
+          warmup=warmup
+      )
+      return median_duration / float(num_elements)
+
+    median_duration = self._run_graph_benchmark(
+        op_or_dataset=dataset,
+        iters=iters,
+        warmup=warmup,
+        session_config=session_config
+    )
+    return median_duration / float(num_elements)

   def run_and_report_benchmark(self,
                                dataset,
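Because the dataset was reduced to a single `skip(num_elements - 1)` fetch above, each timed run still pulls every element through the input pipeline, so dividing the median duration by `num_elements` yields seconds per element. A hedged usage sketch; the dataset and element count are illustrative, and the keyword arguments are assumed to mirror `run_op_benchmark`:

    dataset = dataset_ops.Dataset.range(1000).map(lambda x: x * 2)
    seconds_per_element = self.run_benchmark(
        dataset=dataset,
        num_elements=1000,   # must match what the pipeline produces
        iters=10,
        warmup=True,
        session_config=None)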