Add fused batch norm to tf.layers.

PiperOrigin-RevId: 157893874
This commit is contained in:
Yao Zhang 2017-06-02 17:18:21 -07:00 committed by TensorFlower Gardener
parent f37d0ea47b
commit 675d36be0d
4 changed files with 168 additions and 10 deletions

View File

@ -3555,13 +3555,11 @@ py_test(
],
)
py_test(
cuda_py_test(
name = "layers_normalization_test",
size = "small",
srcs = ["layers/normalization_test.py"],
main = "layers/normalization_test.py",
srcs_version = "PY2AND3",
deps = [
additional_deps = [
":array_ops",
":client_testlib",
":framework_for_generated_wrappers",
@ -3571,6 +3569,7 @@ py_test(
":variables",
"//third_party/py/numpy",
],
main = "layers/normalization_test.py",
)
# -----------------------------------------------------------------------------

View File

@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
moving_variance_initializer: Initializer for the moving variance.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: A string, the name of the layer.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
fused: if `True`, use a faster, fused implementation based on
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: A string, the name of the layer.
"""
def __init__(self,
@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
fused=False,
trainable=True,
name=None,
**kwargs):
@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
self.beta_regularizer = beta_regularizer
self.gamma_regularizer = gamma_regularizer
self.renorm = renorm
self.fused = fused
if self.fused and renorm:
raise ValueError(
'Batch renorm is currently not supported with fused batch norm.')
if renorm:
renorm_clipping = renorm_clipping or {}
keys = ['rmax', 'rmin', 'dmax']
@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
if not input_shape.ndims:
raise ValueError('Input has undefined rank:', input_shape)
ndim = len(input_shape)
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
# output back to its original shape accordingly.
if self.fused and ndim != 4:
raise ValueError(
'Only 4D inputs are currently supported with fused batch norm. '
'Consider reshaping the input to 4D and reshape the output back '
'to its original shape. Got input rank: ', ndim)
if self.axis < 0:
axis = ndim + self.axis
else:
@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
if axis < 0 or axis >= ndim:
raise ValueError('Value of `axis` argument ' + str(self.axis) +
' is out of range for input with rank ' + str(ndim))
if self.fused is None:
self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
if self.fused:
if axis == 1:
self._data_format = 'NCHW'
elif axis == 3:
self._data_format = 'NHWC'
else:
raise ValueError(
'Only axis 1 and 3 are currently supported dimensions for '
'fused batch norm. Got `axis` dimension: ', axis)
param_dim = input_shape[axis]
if not param_dim.value:
raise ValueError('Input has undefined `axis` dimension. Input shape: ',
@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
trainable=True)
else:
self.beta = None
if self.fused:
self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
if self.scale:
self.gamma = self.add_variable(name='gamma',
shape=(param_dim,),
@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
trainable=True)
else:
self.gamma = None
if self.fused:
self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
# Disable variable partitioning when creating the moving mean and variance
partitioner = self._scope.partitioner
@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
self._scope.set_partitioner(partitioner)
self.built = True
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const
def _fused_batch_norm_training():
return nn.fused_batch_norm(
inputs,
gamma,
beta,
epsilon=self.epsilon,
data_format=self._data_format)
def _fused_batch_norm_inference():
return nn.fused_batch_norm(
inputs,
gamma,
beta,
mean=self.moving_mean,
variance=self.moving_variance,
epsilon=self.epsilon,
is_training=False,
data_format=self._data_format)
output, mean, variance = utils.smart_cond(
training, _fused_batch_norm_training, _fused_batch_norm_inference)
training_value = utils.constant_value(training)
if training_value is not False:
decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
mean_update = moving_averages.assign_moving_average(
self.moving_mean, mean, decay, zero_debias=False)
variance_update = moving_averages.assign_moving_average(
self.moving_variance, variance, decay, zero_debias=False)
self.add_update(mean_update, inputs=inputs)
self.add_update(variance_update, inputs=inputs)
return output
def _renorm_correction_and_moments(self, mean, variance, training):
"""Returns the correction and update values for renorm."""
stddev = math_ops.sqrt(variance + self.epsilon)
@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
if self.fused:
return self._fused_batch_norm(inputs, training=training)
# First, compute the axes along which to reduce the mean / variance,
# as well as the broadcast shape to be used for all parameters.
input_shape = inputs.get_shape()
@ -353,7 +427,8 @@ def batch_normalization(inputs,
reuse=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99):
renorm_momentum=0.99,
fused=False):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
@ -415,6 +490,8 @@ def batch_normalization(inputs,
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
fused: if `True`, use a faster, fused implementation based on
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
Returns:
Output tensor.
@ -431,10 +508,11 @@ def batch_normalization(inputs,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
trainable=trainable,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=fused,
trainable=trainable,
name=name,
_reuse=reuse,
_scope=name)

View File

@ -262,6 +262,87 @@ class BNTest(test.TestCase):
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
def test4DInputAxis3Fused(self):
epsilon = 1e-3
bn = normalization_layers.BatchNormalization(
axis=3, epsilon=epsilon, momentum=0.9, fused=True)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
np_output, _, _ = sess.run(
[outputs] + bn.updates, feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
np_inputs = sess.run(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 2))
std = np.std(np_inputs, axis=(0, 1, 2))
variance = np.square(std)
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
# Verify that the axis is normalized during inference.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
def test4DInputAxis1Fused(self):
if test.is_gpu_available(cuda_only=True):
epsilon = 1e-3
bn = normalization_layers.BatchNormalization(
axis=1, epsilon=epsilon, momentum=0.9, fused=True)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
for _ in range(100):
np_output, _, _ = sess.run(
[outputs] + bn.updates, feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
np_inputs = sess.run(inputs)
mean = np.mean(np_inputs, axis=(0, 2, 3))
std = np.std(np_inputs, axis=(0, 2, 3))
variance = np.square(std)
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
# Verify that the axis is normalized during inference.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
def testNegativeAxis(self):
epsilon = 1e-3
bn = normalization_layers.BatchNormalization(

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
}
member_method {
name: "conv1d"