rename op_or_dataset to iterable

This commit is contained in:
Vignesh Kothapalli 2021-02-10 00:06:35 +05:30
parent bbbdd5327d
commit f23830c990
No known key found for this signature in database
GPG Key ID: 7C9D6956FBA21DD7

View File

@ -31,14 +31,14 @@ from tensorflow.python.platform import test
class DatasetBenchmarkBase(test.Benchmark):
"""Base class for dataset benchmarks."""
def _run_eager_benchmark(self, op_or_dataset, iters, warmup):
def _run_eager_benchmark(self, iterable, iters, warmup):
"""Benchmark in eager mode.
Runs the op `iters` times. In each iteration, the benchmark measures
the time it takes to go execute the op.
Args:
op_or_dataset: The tf op or tf.data Dataset to benchmark.
iterable: The tf op or tf.data Dataset to benchmark.
iters: Number of times to repeat the timing.
warmup: If true, warms up the session caches by running an untimed run.
@ -54,10 +54,10 @@ class DatasetBenchmarkBase(test.Benchmark):
for _ in range(iters):
if warmup:
iterator = iter(op_or_dataset)
iterator = iter(iterable)
next(iterator)
iterator = iter(op_or_dataset)
iterator = iter(iterable)
start = time.time()
next(iterator)
end = time.time()
@ -65,7 +65,7 @@ class DatasetBenchmarkBase(test.Benchmark):
return np.median(deltas)
def _run_graph_benchmark(self,
op_or_dataset,
iterable,
iters,
warmup,
session_config,
@ -76,7 +76,7 @@ class DatasetBenchmarkBase(test.Benchmark):
the time it takes to go execute the op.
Args:
op_or_dataset: The tf op or tf.data Dataset to benchmark.
iterable: The tf op or tf.data Dataset to benchmark.
iters: Number of times to repeat the timing.
warmup: If true, warms up the session caches by running an untimed run.
session_config: A ConfigProto protocol buffer with configuration options
@ -95,11 +95,11 @@ class DatasetBenchmarkBase(test.Benchmark):
"Graph mode benchmarking is not supported in eager mode.")
if is_dataset:
iterator = dataset_ops.make_initializable_iterator(dataset)
iterator = dataset_ops.make_initializable_iterator(iterable)
next_element = iterator.get_next()
op = nest.flatten(next_element)[0].op
else:
op = op_or_dataset
op = iterable
for _ in range(iters):
with session.Session(config=session_config) as sess:
@ -142,13 +142,13 @@ class DatasetBenchmarkBase(test.Benchmark):
if context.executing_eagerly():
return self._run_eager_benchmark(
op_or_dataset=op,
iterable=op,
iters=iters,
warmup=warmup
)
return self._run_graph_benchmark(
op_or_dataset=op,
iterable=op,
iters=iters,
warmup=warmup,
session_config=session_config,
@ -200,14 +200,14 @@ class DatasetBenchmarkBase(test.Benchmark):
if context.executing_eagerly():
median_duration = self._run_eager_benchmark(
op_or_dataset=dataset,
iterable=dataset,
iters=iters,
warmup=warmup
)
return median_duration / float(num_elements)
median_duration = self._run_graph_benchmark(
op_or_dataset=dataset,
iterable=dataset,
iters=iters,
warmup=warmup,
session_config=session_config