use initializer op for resetting the iterator
This commit is contained in:
parent
f23830c990
commit
4a3d5de03e
@ -32,10 +32,10 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
"""Base class for dataset benchmarks."""
|
||||
|
||||
def _run_eager_benchmark(self, iterable, iters, warmup):
|
||||
"""Benchmark in eager mode.
|
||||
"""Benchmark the iterable in eager mode.
|
||||
|
||||
Runs the op `iters` times. In each iteration, the benchmark measures
|
||||
the time it takes to go execute the op.
|
||||
Runs the iterable `iters` times. In each iteration, the benchmark measures
|
||||
the time it takes to go execute the iterable.
|
||||
|
||||
Args:
|
||||
iterable: The tf op or tf.data Dataset to benchmark.
|
||||
@ -44,7 +44,7 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
|
||||
Returns:
|
||||
A float, representing the median time (with respect to `iters`)
|
||||
it takes for the op to be executed `iters` num of times.
|
||||
it takes for the iterable to be executed `iters` num of times.
|
||||
"""
|
||||
|
||||
deltas = []
|
||||
@ -69,11 +69,11 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
iters,
|
||||
warmup,
|
||||
session_config,
|
||||
is_dataset=True):
|
||||
"""Benchmarks the op in graph mode.
|
||||
initializer=None):
|
||||
"""Benchmarks the iterable in graph mode.
|
||||
|
||||
Runs the op `iters` times. In each iteration, the benchmark measures
|
||||
the time it takes to go execute the op.
|
||||
Runs the iterable `iters` times. In each iteration, the benchmark measures
|
||||
the time it takes to go execute the iterable.
|
||||
|
||||
Args:
|
||||
iterable: The tf op or tf.data Dataset to benchmark.
|
||||
@ -81,12 +81,12 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
warmup: If true, warms up the session caches by running an untimed run.
|
||||
session_config: A ConfigProto protocol buffer with configuration options
|
||||
for the session. Applicable only for benchmarking in graph mode.
|
||||
is_dataset: A boolean value representing whether the op is a tf.data
|
||||
Dataset or not.
|
||||
initializer: The initializer op required to initialize the
|
||||
iterable.
|
||||
|
||||
Returns:
|
||||
A float, representing the median time (with respect to `iters`)
|
||||
it takes for the op to be executed `iters` num of times.
|
||||
it takes for the iterable to be executed `iters` num of times.
|
||||
"""
|
||||
|
||||
deltas = []
|
||||
@ -94,25 +94,18 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
raise RuntimeError(
|
||||
"Graph mode benchmarking is not supported in eager mode.")
|
||||
|
||||
if is_dataset:
|
||||
iterator = dataset_ops.make_initializable_iterator(iterable)
|
||||
next_element = iterator.get_next()
|
||||
op = nest.flatten(next_element)[0].op
|
||||
else:
|
||||
op = iterable
|
||||
|
||||
for _ in range(iters):
|
||||
with session.Session(config=session_config) as sess:
|
||||
if warmup:
|
||||
if is_dataset:
|
||||
sess.run(iterator.initializer)
|
||||
# Run once to warm up the session caches.
|
||||
sess.run(op)
|
||||
if initializer:
|
||||
sess.run(initializer)
|
||||
sess.run(iterable)
|
||||
|
||||
if is_dataset:
|
||||
sess.run(iterator.initializer)
|
||||
if initializer:
|
||||
sess.run(initializer)
|
||||
start = time.time()
|
||||
sess.run(op)
|
||||
sess.run(iterable)
|
||||
end = time.time()
|
||||
deltas.append(end - start)
|
||||
return np.median(deltas)
|
||||
@ -151,8 +144,7 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
iterable=op,
|
||||
iters=iters,
|
||||
warmup=warmup,
|
||||
session_config=session_config,
|
||||
is_dataset=False
|
||||
session_config=session_config
|
||||
)
|
||||
|
||||
def run_benchmark(self,
|
||||
@ -206,11 +198,15 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
)
|
||||
return median_duration / float(num_elements)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
next_element = iterator.get_next()
|
||||
op = nest.flatten(next_element)[0].op
|
||||
median_duration = self._run_graph_benchmark(
|
||||
iterable=dataset,
|
||||
iterable=op,
|
||||
iters=iters,
|
||||
warmup=warmup,
|
||||
session_config=session_config
|
||||
session_config=session_config,
|
||||
initializer=iterator.initializer
|
||||
)
|
||||
return median_duration / float(num_elements)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user