From f8066f9dc460c2dddf696a1d1bb879515b42421b Mon Sep 17 00:00:00 2001
From: Kibeom Kim
Date: Fri, 21 Aug 2020 16:26:18 -0700
Subject: [PATCH] Add tf.core end-to-end KPI benchmarks.

Add dedicated tf.core end-to-end KPI benchmarks for the following reasons:

- The execution time of most key tf.core APIs depends on other factors,
  such as input size.
- End-to-end time is important because some overheads are not caught by
  internal timing measurements, e.g. b/158246276.

PiperOrigin-RevId: 327893393
Change-Id: Ic01f98d98a8edc9e19f3fad64804abe916d4aee0
---
 tensorflow/python/eager/benchmarks/BUILD   |  21 +++
 .../eager/benchmarks/kpi_benchmark_test.py | 121 ++++++++++++++++++
 2 files changed, 142 insertions(+)
 create mode 100644 tensorflow/python/eager/benchmarks/BUILD
 create mode 100644 tensorflow/python/eager/benchmarks/kpi_benchmark_test.py

diff --git a/tensorflow/python/eager/benchmarks/BUILD b/tensorflow/python/eager/benchmarks/BUILD
new file mode 100644
index 00000000000..8e147d50d9e
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/BUILD
@@ -0,0 +1,21 @@
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cuda_py_test(
+    name = "kpi_benchmark_test",
+    size = "medium",
+    srcs = ["kpi_benchmark_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_windows",  # b/141617449
+        "optonly",
+    ],
+    deps = [
+        "//tensorflow:tensorflow_py_no_contrib",
+        "//tensorflow/python/eager:benchmarks_test_base",
+    ],
+)
diff --git a/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py b/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py
new file mode 100644
index 00000000000..22a70e199f9
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""KPI Benchmarks for low-level eager execution primitives.
+
+This is a suite of full end-to-end integration benchmarks for low-level eager
+execution APIs. It also tracks them as KPI TraceMe events.
+
+To run CPU benchmarks:
+  bazel run -c opt kpi_benchmark_test -- --benchmarks=.
+
+To run GPU benchmarks:
+  bazel run --config=cuda -c opt --copt="-mavx" kpi_benchmark_test -- \
+    --benchmarks=.
+
+To run a subset of benchmarks, use the --benchmarks flag.
+--benchmarks: the list of benchmarks to run. The specified value is
+interpreted as a regular expression, and any benchmark whose name contains a
+partial match to the regular expression is executed.
+e.g. --benchmarks=".*matmul.*" will run all matmul-related benchmarks.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import time
+
+import tensorflow as tf
+
+from tensorflow.python.eager import benchmarks_test_base
+from tensorflow.python.eager import context
+from tensorflow.python.profiler import trace
+
+NUM_ITERATIONS = 30000
+
+
+def _run_benchmark(func, num_iters, execution_mode=None):
+  ctx = context.context()
+  with context.execution_mode(execution_mode):
+    # call func to warm up
+    func()
+    if execution_mode == context.ASYNC:
+      ctx.executor.wait()
+    start = time.time()
+    for _ in range(num_iters):
+      func()
+    if execution_mode == context.ASYNC:
+      ctx.executor.wait()
+    end = time.time()
+
+  return end - start
+
+
+class KpiBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
+  """A collection of KPI benchmarks."""
+
+  def _get_benchmark_name(self):
+    return self._get_name()
+
+  def _run(self, func, num_iters):
+    gc.disable()
+    gc.collect()
+    self.run_report(_run_benchmark, func, num_iters)
+    gc.enable()
+
+  def benchmark_tf_constant_2x2(self):
+    x = [[1., 2.], [3., 4.]]
+
+    def fn():
+      with trace.Trace("tf.constant-2x2"):
+        tf.constant(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_convert_to_tensor_2x2(self):
+    x = [[1., 2.], [3., 4.]]
+
+    def fn():
+      with trace.Trace("tf.convert_to_tensor-2x2"):
+        tf.convert_to_tensor(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_nn_relu_2x2(self):
+    x = tf.constant([[1., 2.], [3., 4.]])
+
+    def fn():
+      with trace.Trace("tf.nn.relu-2x2"):
+        tf.nn.relu(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_function_invocation_identity(self):
+    x = tf.constant([[1., 2.], [3., 4.]])
+
+    @tf.function
+    def identity(x):
+      return x
+
+    def fn():
+      with trace.Trace("tf.function-identity"):
+        identity(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+
+if __name__ == "__main__":
+  tf.test.main()
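
Note (not part of this change): a further KPI benchmark could follow the same
trace.Trace + self._run pattern as the methods above. The sketch below would
sit inside the KpiBenchmarks class; the benchmark_tf_matmul_2x2 name, the
"tf.matmul-2x2" TraceMe label, and the tf.matmul case are illustrative
assumptions only.

  # Hypothetical addition, mirroring the existing benchmark methods.
  def benchmark_tf_matmul_2x2(self):
    x = tf.constant([[1., 2.], [3., 4.]])

    def fn():
      # Wrap the op in a TraceMe so it is tracked as a KPI event, as above.
      with trace.Trace("tf.matmul-2x2"):
        tf.matmul(x, x)

    self._run(fn, NUM_ITERATIONS)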