use initializer op for resetting the iterator

This commit is contained in:
Vignesh Kothapalli 2021-02-10 00:51:24 +05:30
parent f23830c990
commit 4a3d5de03e
No known key found for this signature in database
GPG Key ID: 7C9D6956FBA21DD7

View File

@ -32,10 +32,10 @@ class DatasetBenchmarkBase(test.Benchmark):
"""Base class for dataset benchmarks.""" """Base class for dataset benchmarks."""
def _run_eager_benchmark(self, iterable, iters, warmup): def _run_eager_benchmark(self, iterable, iters, warmup):
"""Benchmark in eager mode. """Benchmark the iterable in eager mode.
Runs the op `iters` times. In each iteration, the benchmark measures Runs the iterable `iters` times. In each iteration, the benchmark measures
the time it takes to go execute the op. the time it takes to go execute the iterable.
Args: Args:
iterable: The tf op or tf.data Dataset to benchmark. iterable: The tf op or tf.data Dataset to benchmark.
@ -44,7 +44,7 @@ class DatasetBenchmarkBase(test.Benchmark):
Returns: Returns:
A float, representing the median time (with respect to `iters`) A float, representing the median time (with respect to `iters`)
it takes for the op to be executed `iters` num of times. it takes for the iterable to be executed `iters` num of times.
""" """
deltas = [] deltas = []
@ -69,11 +69,11 @@ class DatasetBenchmarkBase(test.Benchmark):
iters, iters,
warmup, warmup,
session_config, session_config,
is_dataset=True): initializer=None):
"""Benchmarks the op in graph mode. """Benchmarks the iterable in graph mode.
Runs the op `iters` times. In each iteration, the benchmark measures Runs the iterable `iters` times. In each iteration, the benchmark measures
the time it takes to go execute the op. the time it takes to go execute the iterable.
Args: Args:
iterable: The tf op or tf.data Dataset to benchmark. iterable: The tf op or tf.data Dataset to benchmark.
@ -81,12 +81,12 @@ class DatasetBenchmarkBase(test.Benchmark):
warmup: If true, warms up the session caches by running an untimed run. warmup: If true, warms up the session caches by running an untimed run.
session_config: A ConfigProto protocol buffer with configuration options session_config: A ConfigProto protocol buffer with configuration options
for the session. Applicable only for benchmarking in graph mode. for the session. Applicable only for benchmarking in graph mode.
is_dataset: A boolean value representing whether the op is a tf.data initializer: The initializer op required to initialize the
Dataset or not. iterable.
Returns: Returns:
A float, representing the median time (with respect to `iters`) A float, representing the median time (with respect to `iters`)
it takes for the op to be executed `iters` num of times. it takes for the iterable to be executed `iters` num of times.
""" """
deltas = [] deltas = []
@ -94,25 +94,18 @@ class DatasetBenchmarkBase(test.Benchmark):
raise RuntimeError( raise RuntimeError(
"Graph mode benchmarking is not supported in eager mode.") "Graph mode benchmarking is not supported in eager mode.")
if is_dataset:
iterator = dataset_ops.make_initializable_iterator(iterable)
next_element = iterator.get_next()
op = nest.flatten(next_element)[0].op
else:
op = iterable
for _ in range(iters): for _ in range(iters):
with session.Session(config=session_config) as sess: with session.Session(config=session_config) as sess:
if warmup: if warmup:
if is_dataset:
sess.run(iterator.initializer)
# Run once to warm up the session caches. # Run once to warm up the session caches.
sess.run(op) if initializer:
sess.run(initializer)
sess.run(iterable)
if is_dataset: if initializer:
sess.run(iterator.initializer) sess.run(initializer)
start = time.time() start = time.time()
sess.run(op) sess.run(iterable)
end = time.time() end = time.time()
deltas.append(end - start) deltas.append(end - start)
return np.median(deltas) return np.median(deltas)
@ -151,8 +144,7 @@ class DatasetBenchmarkBase(test.Benchmark):
iterable=op, iterable=op,
iters=iters, iters=iters,
warmup=warmup, warmup=warmup,
session_config=session_config, session_config=session_config
is_dataset=False
) )
def run_benchmark(self, def run_benchmark(self,
@ -206,11 +198,15 @@ class DatasetBenchmarkBase(test.Benchmark):
) )
return median_duration / float(num_elements) return median_duration / float(num_elements)
iterator = dataset_ops.make_initializable_iterator(dataset)
next_element = iterator.get_next()
op = nest.flatten(next_element)[0].op
median_duration = self._run_graph_benchmark( median_duration = self._run_graph_benchmark(
iterable=dataset, iterable=op,
iters=iters, iters=iters,
warmup=warmup, warmup=warmup,
session_config=session_config session_config=session_config,
initializer=iterator.initializer
) )
return median_duration / float(num_elements) return median_duration / float(num_elements)