From 4a3d5de03e3b98765177a154de0f52cabf9b8fa6 Mon Sep 17 00:00:00 2001 From: Vignesh Kothapalli Date: Wed, 10 Feb 2021 00:51:24 +0530 Subject: [PATCH] use initializer op for resetting the iterator --- .../python/data/benchmarks/benchmark_base.py | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/tensorflow/python/data/benchmarks/benchmark_base.py b/tensorflow/python/data/benchmarks/benchmark_base.py index e241e8a87d9..636347fc6aa 100644 --- a/tensorflow/python/data/benchmarks/benchmark_base.py +++ b/tensorflow/python/data/benchmarks/benchmark_base.py @@ -32,10 +32,10 @@ class DatasetBenchmarkBase(test.Benchmark): """Base class for dataset benchmarks.""" def _run_eager_benchmark(self, iterable, iters, warmup): - """Benchmark in eager mode. + """Benchmark the iterable in eager mode. - Runs the op `iters` times. In each iteration, the benchmark measures - the time it takes to go execute the op. + Runs the iterable `iters` times. In each iteration, the benchmark measures + the time it takes to go execute the iterable. Args: iterable: The tf op or tf.data Dataset to benchmark. @@ -44,7 +44,7 @@ class DatasetBenchmarkBase(test.Benchmark): Returns: A float, representing the median time (with respect to `iters`) - it takes for the op to be executed `iters` num of times. + it takes for the iterable to be executed `iters` num of times. """ deltas = [] @@ -69,11 +69,11 @@ class DatasetBenchmarkBase(test.Benchmark): iters, warmup, session_config, - is_dataset=True): - """Benchmarks the op in graph mode. + initializer=None): + """Benchmarks the iterable in graph mode. - Runs the op `iters` times. In each iteration, the benchmark measures - the time it takes to go execute the op. + Runs the iterable `iters` times. In each iteration, the benchmark measures + the time it takes to go execute the iterable. Args: iterable: The tf op or tf.data Dataset to benchmark. @@ -81,12 +81,12 @@ class DatasetBenchmarkBase(test.Benchmark): warmup: If true, warms up the session caches by running an untimed run. session_config: A ConfigProto protocol buffer with configuration options for the session. Applicable only for benchmarking in graph mode. - is_dataset: A boolean value representing whether the op is a tf.data - Dataset or not. + initializer: The initializer op required to initialize the + iterable. Returns: A float, representing the median time (with respect to `iters`) - it takes for the op to be executed `iters` num of times. + it takes for the iterable to be executed `iters` num of times. """ deltas = [] @@ -94,25 +94,18 @@ class DatasetBenchmarkBase(test.Benchmark): raise RuntimeError( "Graph mode benchmarking is not supported in eager mode.") - if is_dataset: - iterator = dataset_ops.make_initializable_iterator(iterable) - next_element = iterator.get_next() - op = nest.flatten(next_element)[0].op - else: - op = iterable - for _ in range(iters): with session.Session(config=session_config) as sess: if warmup: - if is_dataset: - sess.run(iterator.initializer) # Run once to warm up the session caches. - sess.run(op) + if initializer: + sess.run(initializer) + sess.run(iterable) - if is_dataset: - sess.run(iterator.initializer) + if initializer: + sess.run(initializer) start = time.time() - sess.run(op) + sess.run(iterable) end = time.time() deltas.append(end - start) return np.median(deltas) @@ -151,8 +144,7 @@ class DatasetBenchmarkBase(test.Benchmark): iterable=op, iters=iters, warmup=warmup, - session_config=session_config, - is_dataset=False + session_config=session_config ) def run_benchmark(self, @@ -206,11 +198,15 @@ class DatasetBenchmarkBase(test.Benchmark): ) return median_duration / float(num_elements) + iterator = dataset_ops.make_initializable_iterator(dataset) + next_element = iterator.get_next() + op = nest.flatten(next_element)[0].op median_duration = self._run_graph_benchmark( - iterable=dataset, + iterable=op, iters=iters, warmup=warmup, - session_config=session_config + session_config=session_config, + initializer=iterator.initializer ) return median_duration / float(num_elements)