Add eager overhead benchmarks for more Layers.

Establishes a consistent convention for naming and ordering these benchmarks:

benchmark_layers_{module_name}_{class_name}_overhead

PiperOrigin-RevId: 316155083
Change-Id: I336b8881879b3ae216896959c0e504a8acb298c6
This commit is contained in:
Thomas O'Malley 2020-06-12 12:29:20 -07:00 committed by TensorFlower Gardener
parent 425ef02cbb
commit 798f7515f9

View File

@ -22,8 +22,11 @@ import time
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.layers import convolutional as conv_layers
from tensorflow.python.keras.layers import core as core_layers
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@ -87,7 +90,7 @@ class MicroBenchmarksBase(test.Benchmark):
def _run(self, func, num_iters, execution_mode=None):
self.run_report(_run_benchmark, func, num_iters, execution_mode)
def benchmark_tf_keras_layer_call_overhead(self):
def benchmark_layers_call_overhead(self):
class OnlyOverheadLayer(base_layer.Layer):
@ -102,19 +105,90 @@ class MicroBenchmarksBase(test.Benchmark):
self._run(fn, 10000)
def benchmark_tf_keras_dense_overhead(self):
# Naming convention: benchmark_layers_{module_name}_{class}_overhead.
def benchmark_layers_advanced_activations_leaky_relu_overhead(self):
layer = core_layers.Dense(1)
x = ops.convert_to_tensor([[1.]])
layer = advanced_activations.LeakyReLU()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_tf_keras_flatten_overhead(self):
def benchmark_layers_advanced_activations_prelu_overhead(self):
layer = core_layers.Flatten()
layer = advanced_activations.PReLU()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_advanced_activations_elu_overhead(self):
layer = advanced_activations.ELU()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_advanced_activations_thresholded_relu_overhead(self):
layer = advanced_activations.ThresholdedReLU()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_advanced_activations_softmax_overhead(self):
layer = advanced_activations.Softmax()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_advanced_activations_relu_overhead(self):
layer = advanced_activations.ReLU()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_core_masking_overhead(self):
layer = core.Masking()
x = array_ops.ones((1, 1))
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_core_dropout_overhead(self):
layer = core.Dropout(0.5)
x = array_ops.ones((1, 1))
def fn():
layer(x, training=True)
self._run(fn, 10000)
def benchmark_layers_core_flatten_overhead(self):
layer = core.Flatten()
x = ops.convert_to_tensor([[[1.]]])
def fn():
@ -122,9 +196,19 @@ class MicroBenchmarksBase(test.Benchmark):
self._run(fn, 10000)
def benchmark_tf_keras_conv1d_overhead(self):
def benchmark_layers_core_dense_overhead(self):
layer = conv_layers.Conv1D(1, (1,))
layer = core.Dense(1)
x = ops.convert_to_tensor([[1.]])
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_convolutional_conv1d_overhead(self):
layer = convolutional.Conv1D(1, (1,))
x = array_ops.ones((1, 1, 1))
def fn():
@ -132,9 +216,9 @@ class MicroBenchmarksBase(test.Benchmark):
self._run(fn, 10000)
def benchmark_tf_keras_conv2d_overhead(self):
def benchmark_layers_convolutional_conv2d_overhead(self):
layer = conv_layers.Conv2D(1, (1, 1))
layer = convolutional.Conv2D(1, (1, 1))
x = array_ops.ones((1, 1, 1, 1))
def fn():
@ -142,9 +226,9 @@ class MicroBenchmarksBase(test.Benchmark):
self._run(fn, 10000)
def benchmark_tf_keras_conv3d_overhead(self):
def benchmark_layers_convolutional_conv3d_overhead(self):
layer = conv_layers.Conv3D(1, (1, 1, 1))
layer = convolutional.Conv3D(1, (1, 1, 1))
x = array_ops.ones((1, 1, 1, 1, 1))
def fn():
@ -152,6 +236,36 @@ class MicroBenchmarksBase(test.Benchmark):
self._run(fn, 10000)
def benchmark_layers_embeddings_embedding_overhead(self):
layer = embeddings.Embedding(1, 1)
x = array_ops.zeros((1, 1), dtype="int32")
def fn():
layer(x)
self._run(fn, 10000)
def benchmark_layers_normalization_batch_normalization_overhead(self):
layer = normalization.BatchNormalization()
x = array_ops.ones((1, 1))
def fn():
layer(x, training=True)
self._run(fn, 10000)
def benchmark_layers_normalization_layer_normalization_overhead(self):
layer = normalization.LayerNormalization()
x = array_ops.ones((1, 1))
def fn():
layer(x, training=True)
self._run(fn, 10000)
if __name__ == "__main__":
ops.enable_eager_execution()