STT-tensorflow/tensorflow/python/ops/batch_norm_benchmark.py
Mark Daoust fd28b784b2 Apply tf1-tf2 renames to tensorflow/python/ops docstrings and comments.
No code changes, only doc-strings and comments.

PiperOrigin-RevId: 244270926
2019-04-18 16:11:35 -07:00

264 lines
10 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""End-to-end benchmark for batch normalization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import time
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
"""Fused kernel for batch normalization."""
# _batch_norm_with_global_normalization is deprecated in v9
test_util.set_producer_version(ops.get_default_graph(), 8)
# pylint: disable=protected-access
return gen_nn_ops._batch_norm_with_global_normalization(
tensor, mean, variance, beta, gamma, 0.001, scale)
# pylint: enable=protected-access
# Note that the naive implementation is much slower:
# batch_norm = (tensor - mean) * tf.math.rsqrt(variance + 0.001)
# if scale:
# batch_norm *= gamma
# return batch_norm + beta
def batch_norm_py(tensor, mean, variance, beta, gamma, scale):
"""Python implementation of batch normalization."""
return nn_impl.batch_normalization(tensor, mean, variance, beta, gamma if
scale else None, 0.001)
def batch_norm_slow(tensor, mean, variance, beta, gamma, scale):
batch_norm = (tensor - mean) * math_ops.rsqrt(variance + 0.001)
if scale:
batch_norm *= gamma
return batch_norm + beta
def build_graph(device, input_shape, axes, num_layers, mode, scale, train):
"""Build a graph containing a sequence of batch normalizations.
Args:
device: string, the device to run on.
input_shape: shape of the input tensor.
axes: axes that are to be normalized across.
num_layers: number of batch normalization layers in the graph.
mode: "op", "py" or "slow" depending on the implementation.
scale: scale after normalization.
train: if true, also run backprop.
Returns:
An array of tensors to run()
"""
moment_shape = []
keep_dims = mode == "py" or mode == "slow"
if keep_dims:
for axis in range(len(input_shape)):
if axis in axes:
moment_shape.append(1)
else:
moment_shape.append(input_shape[axis])
else:
for axis in range(len(input_shape)):
if axis not in axes:
moment_shape.append(input_shape[axis])
with ops.device("/%s:0" % device):
tensor = variables.Variable(random_ops.truncated_normal(input_shape))
for _ in range(num_layers):
if train:
mean, variance = nn_impl.moments(tensor, axes, keep_dims=keep_dims)
else:
mean = array_ops.zeros(moment_shape)
variance = array_ops.ones(moment_shape)
beta = variables.Variable(array_ops.zeros(moment_shape))
gamma = variables.Variable(constant_op.constant(1.0, shape=moment_shape))
if mode == "py":
tensor = batch_norm_py(tensor, mean, variance, beta, gamma, scale)
elif mode == "op":
tensor = batch_norm_op(tensor, mean, variance, beta, gamma, scale)
elif mode == "slow":
tensor = batch_norm_slow(tensor, mean, variance, beta, gamma, scale)
if train:
return gradients_impl.gradients([tensor], variables.trainable_variables())
else:
return [tensor]
def print_difference(mode, t1, t2):
"""Print the difference in timing between two runs."""
difference = (t2 - t1) / t1 * 100.0
print("=== %s: %.1f%% ===" % (mode, difference))
class BatchNormBenchmark(test.Benchmark):
"""Benchmark batch normalization."""
def _run_graph(self, device, input_shape, axes, num_layers, mode, scale,
train, num_iters):
"""Run the graph and print its execution time.
Args:
device: string, the device to run on.
input_shape: shape of the input tensor.
axes: axes that are to be normalized across.
num_layers: number of batch normalization layers in the graph.
mode: "op", "py" or "slow" depending on the implementation.
scale: scale after normalization.
train: if true, also run backprop.
num_iters: number of steps to run.
Returns:
The duration of the run in seconds.
"""
graph = ops.Graph()
with graph.as_default():
outputs = build_graph(device, input_shape, axes, num_layers, mode, scale,
train)
with session_lib.Session(graph=graph) as session:
variables.global_variables_initializer().run()
_ = session.run([out.op for out in outputs]) # warm up.
start_time = time.time()
for _ in range(num_iters):
_ = session.run([out.op for out in outputs])
duration = time.time() - start_time
print("%s shape:%d/%d #layers:%d mode:%s scale:%r train:%r - %f secs" %
(device, len(input_shape), len(axes), num_layers, mode, scale, train,
duration / num_iters))
name_template = (
"batch_norm_{device}_input_shape_{shape}_axes_{axes}_mode_{mode}_"
"layers_{num_layers}_scale_{scale}_"
"train_{train}")
self.report_benchmark(
name=name_template.format(
device=device,
mode=mode,
num_layers=num_layers,
scale=scale,
train=train,
shape=str(input_shape).replace(" ", ""),
axes=str(axes)).replace(" ", ""),
iters=num_iters,
wall_time=duration / num_iters)
return duration
def benchmark_batch_norm(self):
print("Forward convolution (lower layers).")
shape = [8, 128, 128, 32]
axes = [0, 1, 2]
t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
print("Forward/backward convolution (lower layers).")
t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
print("Forward convolution (higher layers).")
shape = [256, 17, 17, 32]
axes = [0, 1, 2]
t1 = self._run_graph("cpu", shape, axes, 10, "op", True, False, 5)
t2 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, False, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
print("Forward/backward convolution (higher layers).")
t1 = self._run_graph("cpu", shape, axes, 10, "op", True, True, 5)
t2 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 5)
t3 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 5)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
print("Forward fully-connected.")
shape = [1024, 32]
axes = [0]
t1 = self._run_graph("cpu", shape, axes, 10, "py", True, False, 5)
t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, False, 5)
print_difference("py vs slow", t1, t2)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "py", True, False, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, False, 50)
print_difference("py vs slow", t1, t2)
print("Forward/backward fully-connected.")
t1 = self._run_graph("cpu", shape, axes, 10, "py", True, True, 50)
t2 = self._run_graph("cpu", shape, axes, 10, "slow", True, True, 50)
print_difference("py vs slow", t1, t2)
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 5)
t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 5)
print_difference("py vs slow", t1, t2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--use_gpu",
type="bool",
nargs="?",
const=True,
default=True,
help="Run GPU benchmarks."
)
global FLAGS # pylint:disable=global-at-module-level
FLAGS, unparsed = parser.parse_known_args()
test.main(argv=[sys.argv[0]] + unparsed)