From 9d266e05ace872daaa0cf46cdb7a84f9b1591248 Mon Sep 17 00:00:00 2001
From: Yanhui Liang <yhliang@google.com>
Date: Wed, 21 Oct 2020 19:02:39 -0700
Subject: [PATCH] Create benchmarks for Conv2D and LSTM layers.

PiperOrigin-RevId: 338385224
Change-Id: I53fa885addc52c70109ff1ad44fc974d259a60b1
---
 tensorflow/opensource_only.files              |   1 +
 .../keras/benchmarks/layer_benchmarks/BUILD   |  66 +++++++++
 .../layer_benchmarks/layer_benchmarks_test.py | 128 ++++++++++++++++++
 .../layer_benchmarks_test_base.py             |  72 ++++++++++
 .../benchmarks/layer_benchmarks/run_xprof.py  |  40 ++++++
 5 files changed, 307 insertions(+)
 create mode 100644 tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
 create mode 100644 tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
 create mode 100644 tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py
 create mode 100644 tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py

diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index ad02ead9e03..0afe10b825a 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -9,6 +9,7 @@ tensorflow/lite/micro/build_def.bzl
 tensorflow/python/autograph/core/config.py
 tensorflow/python/eager/benchmarks_test_base.py
 tensorflow/python/framework/tfrt_utils.py
+tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py
 tensorflow/python/tpu/profiler/pip_package/BUILD
 tensorflow/python/tpu/profiler/pip_package/README
 tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
new file mode 100644
index 00000000000..7c3b55c02bd
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD
@@ -0,0 +1,66 @@
+# Description:
+#   Implementation of benchmarks on Keras layers.
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "all_py_srcs",
+    srcs = glob(["*.py"]),
+    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
+)
+
+BENCHMARK_TAGS = [
+    "no_oss_py38",  # TODO(b/162044699)
+    "no_pip",  # TODO(b/161253163)
+    "no_windows",  # TODO(b/160628318)
+]
+
+# To run CPU benchmarks:
+#   bazel run -c opt :layer_benchmarks_test -- --benchmarks=.
+
+# To run GPU benchmarks:
+#   bazel run -c opt --config=cuda :layer_benchmarks_test -- \
+#     --benchmarks=.
+
+# To run benchmarks with TFRT:
+#   bazel run -c opt --config=cuda --test_env=EXPERIMENTAL_ENABLE_TFRT=1 \
+#     :layer_benchmarks_test -- --benchmarks=.
+
+# To run a subset of benchmarks, use the --benchmarks flag.
+# --benchmarks: the list of benchmarks to run. The specified value is
+# interpreted as a regular expression, and any benchmark whose name contains
+# a partial match to that regular expression is executed.
+# e.g. --benchmarks=".*LSTM.*" will run all LSTM-layer benchmarks.
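+#
+# For example, to run only the Conv2D benchmarks defined in
+# layer_benchmarks_test.py (illustrative invocation):
+#   bazel run -c opt :layer_benchmarks_test -- --benchmarks=".*Conv2D.*"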
+
+py_library(
+    name = "run_xprof",
+    srcs = ["run_xprof.py"],
+    visibility = ["//tensorflow:internal"],
+)
+
+py_library(
+    name = "layer_benchmarks_test_base",
+    srcs = ["layer_benchmarks_test_base.py"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":run_xprof",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
+    ],
+)
+
+tf_py_test(
+    name = "layer_benchmarks_test",
+    srcs = ["layer_benchmarks_test.py"],
+    python_version = "PY3",
+    tags = BENCHMARK_TAGS,
+    deps = [
+        ":layer_benchmarks_test_base",
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
new file mode 100644
index 00000000000..57f2b18e982
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py
@@ -0,0 +1,128 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks on Keras layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import six
+
+import tensorflow as tf
+from tensorflow.python.keras.benchmarks.layer_benchmarks import layer_benchmarks_test_base
+from tensorflow.python.platform import benchmark
+
+
+def _layer_call_backward(layer, x):
+  """Runs a forward pass of `layer` on `x` and backpropagates a scalar loss."""
+  with tf.GradientTape() as tape:
+    y = layer(x)
+    loss = tf.reduce_mean(y**2)
+
+  _ = tape.gradient(loss, layer.trainable_variables)
+
+
+class KerasLayerBenchmarks(six.with_metaclass(
+    benchmark.ParameterizedBenchmark,
+    layer_benchmarks_test_base.LayerBenchmarksBase)):
+
+  _benchmark_parameters = [
+      ("Conv2D_small_shape", tf.keras.layers.Conv2D,
+       {"filters": 1, "kernel_size": 1, "activation": "relu"},
+       (1, 1, 1, 1), 10000),
+      ("Conv2D_normal_shape", tf.keras.layers.Conv2D,
+       {"filters": 1, "kernel_size": 1, "activation": "relu"},
+       (64, 28, 28, 3), 10000),
+      ("LSTM_small_shape", tf.keras.layers.LSTM,
+       {"units": 1}, (1, 1, 1), 10000),
+      ("LSTM_normal_shape", tf.keras.layers.LSTM,
+       {"units": 4}, (32, 10, 8), 10000),
+  ]
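+
+  # Each benchmark_* method below is expanded by the ParameterizedBenchmark
+  # metaclass into one benchmark per tuple above. Roughly (an illustrative
+  # sketch, not generated code), the first tuple makes benchmark_layer_call run:
+  #   layer = tf.keras.layers.Conv2D(
+  #       filters=1, kernel_size=1, activation="relu")
+  #   x = tf.ones((1, 1, 1, 1))
+  #   self.run_report(functools.partial(layer, x), 10000)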
+
+  def benchmark_layer_call(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+
+    fn = functools.partial(layer, x)
+    self.run_report(fn, num_iters)
+
+  def benchmark_layer_call_with_function(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+    layer.call = tf.function(layer.call)
+
+    fn = functools.partial(layer, x)
+    self.run_report(fn, num_iters)
+
+  def benchmark_layer_call_with_xla(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+    layer.call = tf.function(
+        layer.call, experimental_compile=True)
+
+    fn = functools.partial(layer, x)
+    self.run_report(fn, num_iters)
+
+  def benchmark_layer_call_backward(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+
+    fn = functools.partial(_layer_call_backward, layer, x)
+    self.run_report(fn, num_iters)
+
+  def benchmark_layer_call_backward_with_function(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+    layer.call = tf.function(layer.call)
+
+    fn = functools.partial(_layer_call_backward, layer, x)
+    self.run_report(fn, num_iters)
+
+
+class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
+    benchmark.ParameterizedBenchmark,
+    layer_benchmarks_test_base.LayerBenchmarksBase)):
+
+  _benchmark_parameters = [
+      ("Conv2D_small_shape", tf.keras.layers.Conv2D,
+       {"filters": 1, "kernel_size": 1, "activation": "relu"},
+       (1, 1, 1, 1), 10000),
+      ("Conv2D_normal_shape", tf.keras.layers.Conv2D,
+       {"filters": 1, "kernel_size": 1, "activation": "relu"},
+       (64, 28, 28, 3), 10000),
+      # TODO(b/153480400)
+      # ("LSTM_small_shape", tf.keras.layers.LSTM,
+      #  {"units": 1}, (1, 1, 1), 10000),
+      # ("LSTM_normal_shape", tf.keras.layers.LSTM,
+      #  {"units": 4}, (32, 10, 8), 10000),
+  ]
+
+  def benchmark_layer_call_backward_with_xla(
+      self, layer_cls, layer_args, input_shape, num_iters):
+    layer = layer_cls(**layer_args)
+    x = tf.ones(input_shape)
+    layer.call = tf.function(
+        layer.call, experimental_compile=True)
+
+    fn = functools.partial(_layer_call_backward, layer, x)
+    self.run_report(fn, num_iters)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py
new file mode 100644
index 00000000000..94595c95449
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py
@@ -0,0 +1,72 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Benchmark base to run and report Keras layers benchmark results."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import tensorflow as tf
+
+from tensorflow.python.keras.benchmarks.layer_benchmarks import run_xprof
+
+
+class LayerBenchmarksBase(tf.test.Benchmark):
+  """Run and report benchmark results.
+
+  The first run is without any profiling, to purely measure running time.
+  The second run is with xprof but without Python tracing.
+  The third run is with both xprof and Python tracing.
+  Note: the xprof runs use fewer iterations, capped at 100.
+  """
+
+  def run_report(self, func, num_iters):
+    """Run and report benchmark results for different settings."""
+
+    # 0. Warm up, to trigger any one-time work such as tf.function tracing.
+    func()
+
+    # 1. Run without profiling.
+    start = time.time()
+    for _ in range(num_iters):
+      func()
+    total_time = time.time() - start
+    us_mean_time = total_time * 1e6 / num_iters
+
+    metrics = [
+        {"name": "examples_per_sec",
+         "value": float("{0:.3f}".format(num_iters / total_time))},
+        {"name": "us_per_example",
+         "value": float("{0:.3f}".format(us_mean_time))}]
+
+    # 2. Run with xprof but without Python tracing.
+    num_iters_xprof = min(100, num_iters)
+    xprof_link, us_per_example = run_xprof.run_with_xprof(
+        func, num_iters_xprof, False)
+    # This xprof link will appear in the benchmark dashboard.
+    extras = {
+        "xprof_link": xprof_link,
+        "us_per_example_with_xprof": us_per_example
+    }
+
+    # 3. Run with both xprof and Python tracing.
+    xprof_link, us_per_example = run_xprof.run_with_xprof(
+        func, num_iters_xprof, True)
+    extras["xprof_with_python_trace"] = xprof_link
+    extras["us_per_example_with_xprof_and_python"] = us_per_example
+
+    self.report_benchmark(
+        iters=num_iters, wall_time=us_mean_time, extras=extras, metrics=metrics)
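+
+
+# A minimal usage sketch (hypothetical subclass, not part of this change):
+#
+#   import functools
+#
+#   class DenseLayerBenchmark(LayerBenchmarksBase):
+#
+#     def benchmark_dense_call(self):
+#       layer = tf.keras.layers.Dense(10)
+#       x = tf.ones((32, 10))
+#       self.run_report(functools.partial(layer, x), num_iters=1000)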
diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py
new file mode 100644
index 00000000000..aef4d7b9877
--- /dev/null
+++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py
@@ -0,0 +1,40 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to run a benchmark function under the xprof profiler."""
+
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
+
+import os
+import time
+import uuid
+
+from tensorflow.python.profiler import profiler_v2 as profiler
+
+
+def run_with_xprof(func, num_iters_xprof=100, enable_python_trace=True,
+                   logdir='/tmp/layer_benchmark_xprof/'):
+  """Runs `func` under xprof and returns (logdir, us_per_example)."""
+  suid = str(uuid.uuid4())
+  if enable_python_trace:
+    options = profiler.ProfilerOptions(python_tracer_level=1)
+    logdir = os.path.join(logdir, suid + "_with_python")
+  else:
+    options = profiler.ProfilerOptions(python_tracer_level=0)
+    logdir = os.path.join(logdir, suid)
+
+  start = time.time()
+  with profiler.Profile(logdir, options):
+    for _ in range(num_iters_xprof):
+      func()
+  total_time = time.time() - start
+  us_per_example = float("{0:.3f}".format(total_time * 1e6 / num_iters_xprof))
+  return logdir, us_per_example
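+
+
+# Example usage (hypothetical values): profile a trivial callable 10 times
+# without Python tracing; returns the xprof logdir and mean us per iteration.
+#   logdir, us = run_with_xprof(lambda: None, 10, enable_python_trace=False)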