Create benchmarks for Conv2D and LSTM layers.
PiperOrigin-RevId: 338385224 Change-Id: I53fa885addc52c70109ff1ad44fc974d259a60b1
This commit is contained in:
parent
fb22dff317
commit
9d266e05ac
@ -9,6 +9,7 @@ tensorflow/lite/micro/build_def.bzl
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/python/eager/benchmarks_test_base.py
|
||||
tensorflow/python/framework/tfrt_utils.py
|
||||
tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py
|
||||
tensorflow/python/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/python/tpu/profiler/pip_package/README
|
||||
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
|
||||
|
66
tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
Normal file
66
tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
Normal file
@ -0,0 +1,66 @@
|
||||
# Description:
|
||||
# Implementation of benchmarks on Keras layers.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_py_srcs",
|
||||
srcs = glob(["*.py"]),
|
||||
visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
|
||||
)
|
||||
|
||||
BECHMARK_TAGS = [
|
||||
"no_oss_py38", # TODO(b/162044699)
|
||||
"no_pip", # TODO(b/161253163)
|
||||
"no_windows", # TODO(b/160628318)
|
||||
]
|
||||
|
||||
# To run CPU benchmarks:
|
||||
# bazel run -c opt benchmarks_test -- --benchmarks=.
|
||||
|
||||
# To run GPU benchmarks:
|
||||
# bazel run -c opt --config=cuda benchmarks_test -- \
|
||||
# --benchmarks=.
|
||||
|
||||
# To run benchmarks with TFRT:
|
||||
# bazel run -c opt --config=cuda --test_env=EXPERIMENTAL_ENABLE_TFRT=1 benchmarks_test -- \
|
||||
# --benchmarks=.
|
||||
|
||||
# To run a subset of benchmarks using --benchmarks flag.
|
||||
# --benchmarks: the list of benchmarks to run. The specified value is interpreted
|
||||
# as a regular expression and any benchmark whose name contains a partial match
|
||||
# to the regular expression is executed.
|
||||
# e.g. --benchmarks=".*lstm*." will run all lstm layer related benchmarks.
|
||||
|
||||
py_library(
|
||||
name = "run_xprof",
|
||||
srcs = ["run_xprof.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_benchmarks_test_base",
|
||||
srcs = ["layer_benchmarks_test_base.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":run_xprof",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/keras/benchmarks:profiler_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "layer_benchmarks_test",
|
||||
srcs = ["layer_benchmarks_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = BECHMARK_TAGS,
|
||||
deps = [
|
||||
":layer_benchmarks_test_base",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
@ -0,0 +1,128 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Benchmarks on Keras layers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import six
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.benchmarks.layer_benchmarks import layer_benchmarks_test_base
|
||||
from tensorflow.python.platform import benchmark
|
||||
|
||||
|
||||
def _layer_call_backward(layer, x):
|
||||
with tf.GradientTape() as tape:
|
||||
y = layer(x)
|
||||
loss = tf.reduce_mean(y**2)
|
||||
|
||||
_ = tape.gradient(loss, layer.trainable_variables)
|
||||
|
||||
|
||||
class KerasLayerBenchmarks(six.with_metaclass(
|
||||
benchmark.ParameterizedBenchmark,
|
||||
layer_benchmarks_test_base.LayerBenchmarksBase)):
|
||||
|
||||
_benchmark_parameters = [
|
||||
("Conv2D_small_shape", tf.keras.layers.Conv2D,
|
||||
{"filters": 1, "kernel_size": 1, "activation": "relu"},
|
||||
(1, 1, 1, 1), 10000),
|
||||
("Conv2D_normal_shape", tf.keras.layers.Conv2D,
|
||||
{"filters": 1, "kernel_size": 1, "activation": "relu"},
|
||||
(64, 28, 28, 3), 10000),
|
||||
("LSTM_small_shape", tf.keras.layers.LSTM,
|
||||
{"units": 1}, (1, 1, 1), 10000),
|
||||
("LSTM_normal_shape", tf.keras.layers.LSTM,
|
||||
{"units": 4}, (32, 10, 8), 10000),
|
||||
]
|
||||
|
||||
def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
|
||||
fn = functools.partial(layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
def benchmark_layer_call_with_function(
|
||||
self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
layer.call = tf.function(layer.call)
|
||||
|
||||
fn = functools.partial(layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
def benchmark_layer_call_with_xla(
|
||||
self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
layer.call = tf.function(
|
||||
layer.call, experimental_compile=True)
|
||||
|
||||
fn = functools.partial(layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
def benchmark_layer_call_backward(
|
||||
self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
|
||||
fn = functools.partial(_layer_call_backward, layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
def benchmark_layer_call_backward_with_function(
|
||||
self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
layer.call = tf.function(layer.call)
|
||||
|
||||
fn = functools.partial(_layer_call_backward, layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
|
||||
class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
|
||||
benchmark.ParameterizedBenchmark,
|
||||
layer_benchmarks_test_base.LayerBenchmarksBase)):
|
||||
|
||||
_benchmark_parameters = [
|
||||
("Conv2D_small_shape", tf.keras.layers.Conv2D,
|
||||
{"filters": 1, "kernel_size": 1, "activation": "relu"},
|
||||
(1, 1, 1, 1), 10000),
|
||||
("Conv2D_normal_shape", tf.keras.layers.Conv2D,
|
||||
{"filters": 1, "kernel_size": 1, "activation": "relu"},
|
||||
(64, 28, 28, 3), 10000),
|
||||
# TODO(b/153480400)
|
||||
# ("LSTM_small_shape", tf.keras.layers.LSTM,
|
||||
# {"units": 1}, (1, 1, 1), 10000),
|
||||
# ("LSTM_normal_shape", tf.keras.layers.LSTM,
|
||||
# {"units": 4}, (32, 10, 8), 10000),
|
||||
]
|
||||
|
||||
def benchmark_layer_call_backward_with_xla(
|
||||
self, layer_cls, layer_args, input_shape, num_iters):
|
||||
layer = layer_cls(**layer_args)
|
||||
x = tf.ones(input_shape)
|
||||
layer.call = tf.function(
|
||||
layer.call, experimental_compile=True)
|
||||
|
||||
fn = functools.partial(_layer_call_backward, layer, x)
|
||||
self.run_report(fn, num_iters)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -0,0 +1,72 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Benchmark base to run and report Keras layers benchmark results."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras.benchmarks.layer_benchmarks import run_xprof
|
||||
|
||||
|
||||
class LayerBenchmarksBase(tf.test.Benchmark):
|
||||
"""Run and report benchmark results.
|
||||
|
||||
The first run is without any profiling to purly measure running time.
|
||||
Second run is with xprof but no python trace.
|
||||
Third run is with xprof and python trace.
|
||||
Note: xprof runs fewer iterations, and the maximum iterations is 100.
|
||||
"""
|
||||
|
||||
def run_report(self, func, num_iters):
|
||||
"""Run and report benchmark results for different settings."""
|
||||
|
||||
# 0. Warm up.
|
||||
func()
|
||||
|
||||
# 1. Run without profiling.
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
func()
|
||||
total_time = time.time() - start
|
||||
us_mean_time = total_time * 1e6 / num_iters
|
||||
|
||||
metrics = [
|
||||
{"name": "examples_per_sec",
|
||||
"value": float("{0:.3f}".format(num_iters / total_time))},
|
||||
{"name": "us_per_example",
|
||||
"value": float("{0:.3f}".format(us_mean_time))}]
|
||||
|
||||
# 2. Run with xprof with no python trace.
|
||||
num_iters_xprof = min(100, num_iters)
|
||||
xprof_link, us_per_example = run_xprof.run_with_xprof(
|
||||
func, num_iters_xprof, False)
|
||||
# This xprof link will appear in the benchmark dashboard.
|
||||
extras = {
|
||||
"xprof_link": xprof_link,
|
||||
"us_per_example_with_xprof": us_per_example
|
||||
}
|
||||
|
||||
# 3. Run with xprof and python trace.
|
||||
xprof_link, us_per_example = run_xprof.run_with_xprof(
|
||||
func, num_iters_xprof, True)
|
||||
extras["xprof_with_python_trace"] = xprof_link
|
||||
extras["us_per_example_with_xprof_and_python"] = us_per_example
|
||||
|
||||
self.report_benchmark(
|
||||
iters=num_iters, wall_time=us_mean_time, extras=extras, metrics=metrics)
|
@ -0,0 +1,40 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import as _absolute_import
|
||||
from __future__ import division as _division
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
|
||||
def run_with_xprof(self, func, num_iters_xprof=100, enable_python_trace=True,
|
||||
logdir='/tmp/layer_benchmark_xprof/'):
|
||||
suid = str(uuid.uuid4())
|
||||
if enable_python_trace:
|
||||
options = profiler.ProfilerOptions(python_tracer_level=1)
|
||||
logdir = os.path.join(logdir, str(uuid.uuid4()) + "_with_python")
|
||||
else:
|
||||
options = profiler.ProfilerOptions(python_tracer_level=0)
|
||||
logdir = os.path.join(logdir, suid)
|
||||
|
||||
start = time.time()
|
||||
with profiler.Profile(logdir, options):
|
||||
for _ in range(num_iters_xprof):
|
||||
func()
|
||||
total_time = time.time() - start
|
||||
us_per_example = float("{0:.3f}".format(total_time * 1e6 / num_iters_xprof))
|
||||
return logdir, us_per_example
|
Loading…
Reference in New Issue
Block a user