refactor eager and graph benchmark methods

Vignesh Kothapalli 2021-02-09 17:00:06 +05:30
parent 4ef2344b3a
commit bbbdd5327d


@@ -31,10 +31,97 @@ from tensorflow.python.platform import test
 class DatasetBenchmarkBase(test.Benchmark):
   """Base class for dataset benchmarks."""
 
+  def _run_eager_benchmark(self, op_or_dataset, iters, warmup):
+    """Benchmarks the op in eager mode.
+
+    Runs the op `iters` times. In each iteration, the benchmark measures
+    the time it takes to execute the op.
+
+    Args:
+      op_or_dataset: The tf op or tf.data Dataset to benchmark.
+      iters: Number of times to repeat the timing.
+      warmup: If true, warms up the session caches by running an untimed run.
+
+    Returns:
+      A float, the median of the per-iteration wall times (in seconds)
+      across the `iters` runs.
+    """
+    deltas = []
+    if not context.executing_eagerly():
+      raise RuntimeError(
+          "Eager mode benchmarking is not supported in graph mode.")
+
+    for _ in range(iters):
+      if warmup:
+        iterator = iter(op_or_dataset)
+        next(iterator)
+
+      iterator = iter(op_or_dataset)
+      start = time.time()
+      next(iterator)
+      end = time.time()
+      deltas.append(end - start)
+    return np.median(deltas)
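
For illustration, a minimal sketch of how a subclass might exercise the new eager helper. The RangeBenchmark class, the range dataset, and the benchmark_base module path are assumptions for the sketch, not part of this commit:

from tensorflow.python.data.benchmarks import benchmark_base  # assumed module path
from tensorflow.python.data.ops import dataset_ops


class RangeBenchmark(benchmark_base.DatasetBenchmarkBase):

  def benchmark_range(self):
    dataset = dataset_ops.Dataset.range(1000)
    # Median wall time (seconds) of fetching the first element, over 10
    # timed runs, each preceded by an untimed warmup iteration.
    wall_time = self._run_eager_benchmark(
        op_or_dataset=dataset, iters=10, warmup=True)
    self.report_benchmark(
        iters=10, wall_time=wall_time, name="range_first_element")
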
+  def _run_graph_benchmark(self,
+                           op_or_dataset,
+                           iters,
+                           warmup,
+                           session_config,
+                           is_dataset=True):
+    """Benchmarks the op in graph mode.
+
+    Runs the op `iters` times. In each iteration, the benchmark measures
+    the time it takes to execute the op.
+
+    Args:
+      op_or_dataset: The tf op or tf.data Dataset to benchmark.
+      iters: Number of times to repeat the timing.
+      warmup: If true, warms up the session caches by running an untimed run.
+      session_config: A ConfigProto protocol buffer with configuration options
+        for the session. Applicable only for benchmarking in graph mode.
+      is_dataset: A boolean value indicating whether `op_or_dataset` is a
+        tf.data Dataset.
+
+    Returns:
+      A float, the median of the per-iteration wall times (in seconds)
+      across the `iters` runs.
+    """
+    deltas = []
+    if context.executing_eagerly():
+      raise RuntimeError(
+          "Graph mode benchmarking is not supported in eager mode.")
+
+    if is_dataset:
+      iterator = dataset_ops.make_initializable_iterator(op_or_dataset)
+      next_element = iterator.get_next()
+      op = nest.flatten(next_element)[0].op
+    else:
+      op = op_or_dataset
+    for _ in range(iters):
+      with session.Session(config=session_config) as sess:
+        if warmup:
+          if is_dataset:
+            sess.run(iterator.initializer)
+          # Run once to warm up the session caches.
+          sess.run(op)
+
+        if is_dataset:
+          sess.run(iterator.initializer)
+        start = time.time()
+        sess.run(op)
+        end = time.time()
+        deltas.append(end - start)
+    return np.median(deltas)
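
A minimal sketch of driving the graph-mode helper directly, assuming a DatasetBenchmarkBase instance (the variable names are illustrative). Each timed iteration gets a fresh session, and only the sess.run(op) call itself is timed, so session startup and iterator initialization stay outside the measurement:

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops

benchmark = DatasetBenchmarkBase()
with ops.Graph().as_default():  # forces graph mode so the eager check passes
  dataset = dataset_ops.Dataset.range(1000)
  # Builds the iterator once, then times one sess.run(op) per iteration,
  # re-initializing the iterator inside each fresh session.
  wall_time = benchmark._run_graph_benchmark(
      op_or_dataset=dataset, iters=10, warmup=True,
      session_config=None, is_dataset=True)
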
   def run_op_benchmark(self,
                        op,
                        iters=1,
-                       warmup=True):
+                       warmup=True,
+                       session_config=None):
     """Benchmarks the op.
 
     Runs the op `iters` times. In each iteration, the benchmark measures
@@ -44,6 +131,8 @@ class DatasetBenchmarkBase(test.Benchmark):
       op: The tf op to benchmark.
       iters: Number of times to repeat the timing.
       warmup: If true, warms up the session caches by running an untimed run.
+      session_config: A ConfigProto protocol buffer with configuration options
+        for the session. Applicable only for benchmarking in graph mode.
 
     Returns:
       A float, representing the per-execution wall time of the op in seconds.
@@ -52,27 +141,19 @@ class DatasetBenchmarkBase(test.Benchmark):
"""
if context.executing_eagerly():
for _ in range(iters):
if warmup:
iterator = iter(op)
next(iterator)
return self._run_eager_benchmark(
op_or_dataset=op,
iters=iters,
warmup=warmup
)
iterator = iter(op)
start = time.time()
next(iterator)
end = time.time()
return (end - start) / iters
with session.Session() as sess:
if warmup:
# Run once to warm up the session caches.
sess.run(op)
start = time.time()
for _ in range(iters):
sess.run(op)
end = time.time()
return (end - start) / iters
return self._run_graph_benchmark(
op_or_dataset=op,
iters=iters,
warmup=warmup,
session_config=session_config,
is_dataset=False
)
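
A hedged example of threading the new session_config argument through run_op_benchmark in graph mode. The matmul op and the single-thread ConfigProto are illustrative, and benchmark is a DatasetBenchmarkBase instance as in the sketch above:

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops

with ops.Graph().as_default():
  op = math_ops.matmul(
      random_ops.random_normal([256, 256]),
      random_ops.random_normal([256, 256]))
  # The config is forwarded to session.Session(config=...) inside
  # _run_graph_benchmark; it has no effect on the eager path.
  config = config_pb2.ConfigProto(inter_op_parallelism_threads=1)
  wall_time = benchmark.run_op_benchmark(
      op=op, iters=100, warmup=True, session_config=config)
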
   def run_benchmark(self,
                     dataset,
@@ -116,38 +197,22 @@ class DatasetBenchmarkBase(test.Benchmark):
     # to execute upstream computation. If it is optimized in the future,
     # we will have to change this code.
     dataset = dataset.skip(num_elements - 1)
-    deltas = []
     if context.executing_eagerly():
-      for _ in range(iters):
-        if warmup:
-          iterator = iter(dataset)
-          next(iterator)
+      median_duration = self._run_eager_benchmark(
+          op_or_dataset=dataset,
+          iters=iters,
+          warmup=warmup)
+      return median_duration / float(num_elements)
-        iterator = iter(dataset)
-        start = time.time()
-        next(iterator)
-        end = time.time()
-        deltas.append(end - start)
-      return np.median(deltas) / float(num_elements)
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-    next_element = iterator.get_next()
-    next_element = nest.flatten(next_element)[0]
-    for _ in range(iters):
-      with session.Session(config=session_config) as sess:
-        if warmup:
-          # Run once to warm up the session caches.
-          sess.run(iterator.initializer)
-          sess.run(next_element.op)
-        sess.run(iterator.initializer)
-        start = time.time()
-        sess.run(next_element.op)
-        end = time.time()
-        deltas.append(end - start)
-    return np.median(deltas) / float(num_elements)
+    median_duration = self._run_graph_benchmark(
+        op_or_dataset=dataset,
+        iters=iters,
+        warmup=warmup,
+        session_config=session_config)
+    return median_duration / float(num_elements)
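
The refactor preserves run_benchmark's per-element semantics: the median duration now comes from the shared helpers and is divided by num_elements, so the returned value is seconds per element. A sketch under that assumption (keyword names beyond those visible in this diff are guesses):

dataset = dataset_ops.Dataset.range(1000).map(lambda x: x + 1)
# Because of the skip(num_elements - 1) above, each timed run fetches a
# single element whose production forces the 999 upstream elements.
seconds_per_element = benchmark.run_benchmark(
    dataset=dataset, num_elements=1000, iters=5)
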
   def run_and_report_benchmark(self,
                                dataset,