Add fused batch norm to tf.layers.

PiperOrigin-RevId: 157893874
Yao Zhang 2017-06-02 17:18:21 -07:00 committed by TensorFlower Gardener
parent f37d0ea47b
commit 675d36be0d
4 changed files with 168 additions and 10 deletions
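
A minimal usage sketch (illustrative only, not part of this commit) of the new `fused` argument through the functional tf.layers interface extended below; the 4-D NHWC input is required by the fused path, and the usual tf.layers convention of collecting layer updates in tf.GraphKeys.UPDATE_OPS is assumed:

# Illustrative sketch: exercise the new `fused` argument of tf.layers.batch_normalization.
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.random((5, 4, 3, 6)), dtype=tf.float32)  # fused currently requires a 4-D input
is_training = tf.placeholder(tf.bool)
# axis=3 (NHWC) and axis=1 (NCHW) are the only layouts the fused kernel accepts.
y = tf.layers.batch_normalization(x, axis=3, momentum=0.9, epsilon=1e-3,
                                  fused=True, training=is_training)
# The moving-average assignments are registered as layer updates (assumed here to be
# collected in tf.GraphKeys.UPDATE_OPS); run them alongside the forward pass when training.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run([y] + update_ops, feed_dict={is_training: True})
  sess.run(y, feed_dict={is_training: False})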


@@ -3555,13 +3555,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "layers_normalization_test",
     size = "small",
     srcs = ["layers/normalization_test.py"],
-    main = "layers/normalization_test.py",
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":array_ops",
         ":client_testlib",
         ":framework_for_generated_wrappers",
@@ -3571,6 +3569,7 @@ py_test(
         ":variables",
         "//third_party/py/numpy",
     ],
+    main = "layers/normalization_test.py",
 )
 
 # -----------------------------------------------------------------------------
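
For readability (not part of the diff), the test target as it reads once both hunks above are applied: `deps` and `srcs_version` are dropped, and `cuda_py_test` takes `additional_deps` and `main` instead.

cuda_py_test(
    name = "layers_normalization_test",
    size = "small",
    srcs = ["layers/normalization_test.py"],
    additional_deps = [
        ":array_ops",
        ":client_testlib",
        ":framework_for_generated_wrappers",
        # (deps between the two hunks are not shown in this diff)
        ":variables",
        "//third_party/py/numpy",
    ],
    main = "layers/normalization_test.py",
)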


@@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
     moving_variance_initializer: Initializer for the moving variance.
     beta_regularizer: Optional regularizer for the beta weight.
     gamma_regularizer: Optional regularizer for the gamma weight.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
-    name: A string, the name of the layer.
     renorm: Whether to use Batch Renormalization
       (https://arxiv.org/abs/1702.03275). This adds extra variables during
       training. The inference is the same for either value of this parameter.
@@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+    name: A string, the name of the layer.
   """
 
   def __init__(self,
@@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
                renorm=False,
                renorm_clipping=None,
                renorm_momentum=0.99,
+               fused=False,
                trainable=True,
                name=None,
                **kwargs):
@@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
     self.beta_regularizer = beta_regularizer
     self.gamma_regularizer = gamma_regularizer
     self.renorm = renorm
+    self.fused = fused
+    if self.fused and renorm:
+      raise ValueError(
+          'Batch renorm is currently not supported with fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
     if not input_shape.ndims:
       raise ValueError('Input has undefined rank:', input_shape)
     ndim = len(input_shape)
+    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+    # output back to its original shape accordingly.
+    if self.fused and ndim != 4:
+      raise ValueError(
+          'Only 4D inputs are currently supported with fused batch norm. '
+          'Consider reshaping the input to 4D and reshape the output back '
+          'to its original shape. Got input rank: ', ndim)
     if self.axis < 0:
       axis = ndim + self.axis
     else:
@@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
     if axis < 0 or axis >= ndim:
       raise ValueError('Value of `axis` argument ' + str(self.axis) +
                        ' is out of range for input with rank ' + str(ndim))
+
+    if self.fused is None:
+      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+
+    if self.fused:
+      if axis == 1:
+        self._data_format = 'NCHW'
+      elif axis == 3:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError(
+            'Only axis 1 and 3 are currently supported dimensions for '
+            'fused batch norm. Got `axis` dimension: ', axis)
+
     param_dim = input_shape[axis]
     if not param_dim.value:
       raise ValueError('Input has undefined `axis` dimension. Input shape: ',
@@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
                                     trainable=True)
     else:
       self.beta = None
+      if self.fused:
+        self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
     if self.scale:
       self.gamma = self.add_variable(name='gamma',
                                      shape=(param_dim,),
@@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
                                      trainable=True)
     else:
       self.gamma = None
+      if self.fused:
+        self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
 
     # Disable variable partitioning when creating the moving mean and variance
     partitioner = self._scope.partitioner
@@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
+  def _fused_batch_norm(self, inputs, training):
+    """Returns the output of fused batch norm."""
+    beta = self.beta if self.center else self._beta_const
+    gamma = self.gamma if self.scale else self._gamma_const
+
+    def _fused_batch_norm_training():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          epsilon=self.epsilon,
+          data_format=self._data_format)
+
+    def _fused_batch_norm_inference():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          mean=self.moving_mean,
+          variance=self.moving_variance,
+          epsilon=self.epsilon,
+          is_training=False,
+          data_format=self._data_format)
+
+    output, mean, variance = utils.smart_cond(
+        training, _fused_batch_norm_training, _fused_batch_norm_inference)
+
+    training_value = utils.constant_value(training)
+    if training_value is not False:
+      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
+      mean_update = moving_averages.assign_moving_average(
+          self.moving_mean, mean, decay, zero_debias=False)
+      variance_update = moving_averages.assign_moving_average(
+          self.moving_variance, variance, decay, zero_debias=False)
+      self.add_update(mean_update, inputs=inputs)
+      self.add_update(variance_update, inputs=inputs)
+
+    return output
+
   def _renorm_correction_and_moments(self, mean, variance, training):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
@@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=False):
+    if self.fused:
+      return self._fused_batch_norm(inputs, training=training)
+
     # First, compute the axes along which to reduce the mean / variance,
     # as well as the broadcast shape to be used for all parameters.
     input_shape = inputs.get_shape()
@@ -353,7 +427,8 @@ def batch_normalization(inputs,
                         reuse=None,
                         renorm=False,
                         renorm_clipping=None,
-                        renorm_momentum=0.99):
+                        renorm_momentum=0.99,
+                        fused=False):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@ -415,6 +490,8 @@ def batch_normalization(inputs,
and should be neither too small (which would add noise) nor too large and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied (which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference. to get the means and variances for inference.
fused: if `True`, use a faster, fused implementation based on
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
Returns: Returns:
Output tensor. Output tensor.
@ -431,10 +508,11 @@ def batch_normalization(inputs,
moving_variance_initializer=moving_variance_initializer, moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer, beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer, gamma_regularizer=gamma_regularizer,
trainable=trainable,
renorm=renorm, renorm=renorm,
renorm_clipping=renorm_clipping, renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum, renorm_momentum=renorm_momentum,
fused=fused,
trainable=trainable,
name=name, name=name,
_reuse=reuse, _reuse=reuse,
_scope=name) _scope=name)
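
To summarize what the new build() logic above decides, here is a small standalone restatement (illustrative, not part of the commit; the helper name is made up). Note also the trick in _fused_batch_norm itself: when `training` is a tensor, _smart_select switches the decay to 1.0 at run time so the assign_moving_average ops become no-ops during inference, while a Python-constant training=False skips creating the update ops altogether.

# Illustrative restatement (hypothetical helper, not in the commit) of how build()
# resolves the fused path and the data format passed to nn.fused_batch_norm.
def _resolve_fused(fused, renorm, ndim, axis):
  """Mirrors the checks added to BatchNormalization.build()."""
  if fused and ndim != 4:
    raise ValueError('Only 4D inputs are currently supported with fused batch norm.')
  if fused is None:
    # Auto mode: fused only without renorm, for 4D inputs, channels on axis 1 or 3.
    fused = not renorm and ndim == 4 and axis in [1, 3]
  if not fused:
    return False, None
  if axis == 1:
    return True, 'NCHW'
  if axis == 3:
    return True, 'NHWC'
  raise ValueError('Only axis 1 and 3 are currently supported dimensions for '
                   'fused batch norm.')

assert _resolve_fused(None, False, 4, 3) == (True, 'NHWC')
assert _resolve_fused(None, True, 4, 3) == (False, None)   # renorm disables auto-fusion
assert _resolve_fused(None, False, 2, 1) == (False, None)  # non-4D input disables auto-fusion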


@@ -262,6 +262,87 @@ class BNTest(test.TestCase):
       self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
       self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
+  def test4DInputAxis3Fused(self):
+    epsilon = 1e-3
+    bn = normalization_layers.BatchNormalization(
+        axis=3, epsilon=epsilon, momentum=0.9, fused=True)
+    inputs = variables.Variable(
+        np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+    training = array_ops.placeholder(dtype='bool')
+    outputs = bn.apply(inputs, training=training)
+
+    with self.test_session() as sess:
+      # Test training with placeholder learning phase.
+      sess.run(variables.global_variables_initializer())
+      np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+      np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+      np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+      for _ in range(100):
+        np_output, _, _ = sess.run(
+            [outputs] + bn.updates, feed_dict={training: True})
+        # Verify that the axis is normalized during training.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+      # Verify that the statistics are updated during training.
+      moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+      np_inputs = sess.run(inputs)
+      mean = np.mean(np_inputs, axis=(0, 1, 2))
+      std = np.std(np_inputs, axis=(0, 1, 2))
+      variance = np.square(std)
+      self.assertAllClose(mean, moving_mean, atol=1e-2)
+      self.assertAllClose(variance, moving_var, atol=1e-2)
+
+      # Test inference with placeholder learning phase.
+      np_output = sess.run(outputs, feed_dict={training: False})
+
+      # Verify that the axis is normalized during inference.
+      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+  def test4DInputAxis1Fused(self):
+    if test.is_gpu_available(cuda_only=True):
+      epsilon = 1e-3
+      bn = normalization_layers.BatchNormalization(
+          axis=1, epsilon=epsilon, momentum=0.9, fused=True)
+      inputs = variables.Variable(
+          np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+      training = array_ops.placeholder(dtype='bool')
+      outputs = bn.apply(inputs, training=training)
+
+      with self.test_session() as sess:
+        # Test training with placeholder learning phase.
+        sess.run(variables.global_variables_initializer())
+        np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+        np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+        np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+        for _ in range(100):
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
+          # Verify that the axis is normalized during training.
+          normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+          self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+          self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+        # Verify that the statistics are updated during training.
+        moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+        np_inputs = sess.run(inputs)
+        mean = np.mean(np_inputs, axis=(0, 2, 3))
+        std = np.std(np_inputs, axis=(0, 2, 3))
+        variance = np.square(std)
+        self.assertAllClose(mean, moving_mean, atol=1e-2)
+        self.assertAllClose(variance, moving_var, atol=1e-2)
+
+        # Test inference with placeholder learning phase.
+        np_output = sess.run(outputs, feed_dict={training: False})
+
+        # Verify that the axis is normalized during inference.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
   def testNegativeAxis(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(
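
A quick numpy check (illustrative, not part of the tests) of why 100 update steps with momentum 0.9 satisfy the atol=1e-2 assertions above: with a constant batch, assign_moving_average(..., zero_debias=False) computes moving = decay * moving + (1 - decay) * batch_stat, so the residual shrinks by a factor of 0.9 per step.

# Illustrative numpy check of the moving-mean convergence the new tests rely on.
import numpy as np

np.random.seed(0)
inputs = np.random.random((5, 4, 3, 6)) + 100   # same distribution as the tests
batch_mean = np.mean(inputs, axis=(0, 1, 2))     # per-channel mean (axis=3 case)

moving_mean = np.zeros(6)                        # moving_mean_initializer defaults to zeros
for _ in range(100):
  # assign_moving_average with zero_debias=False and decay = momentum = 0.9
  moving_mean = 0.9 * moving_mean + 0.1 * batch_mean

print(np.max(np.abs(moving_mean - batch_mean)))  # ~2.7e-3, within the atol=1e-2 bound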


@@ -14,7 +14,7 @@ tf_module {
   }
   member_method {
    name: "batch_normalization"
-    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
+    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
   }
   member_method {
     name: "conv1d"