Add fused batch norm to tf.layers.
PiperOrigin-RevId: 157893874
This commit is contained in:
parent
f37d0ea47b
commit
675d36be0d
@ -3555,13 +3555,11 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "layers_normalization_test",
|
||||
size = "small",
|
||||
srcs = ["layers/normalization_test.py"],
|
||||
main = "layers/normalization_test.py",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
additional_deps = [
|
||||
":array_ops",
|
||||
":client_testlib",
|
||||
":framework_for_generated_wrappers",
|
||||
@ -3571,6 +3569,7 @@ py_test(
|
||||
":variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
main = "layers/normalization_test.py",
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
|
||||
moving_variance_initializer: Initializer for the moving variance.
|
||||
beta_regularizer: Optional regularizer for the beta weight.
|
||||
gamma_regularizer: Optional regularizer for the gamma weight.
|
||||
trainable: Boolean, if `True` also add variables to the graph collection
|
||||
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||
name: A string, the name of the layer.
|
||||
renorm: Whether to use Batch Renormalization
|
||||
(https://arxiv.org/abs/1702.03275). This adds extra variables during
|
||||
training. The inference is the same for either value of this parameter.
|
||||
@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
|
||||
and should be neither too small (which would add noise) nor too large
|
||||
(which would give stale estimates). Note that `momentum` is still applied
|
||||
to get the means and variances for inference.
|
||||
fused: if `True`, use a faster, fused implementation based on
|
||||
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
|
||||
trainable: Boolean, if `True` also add variables to the graph collection
|
||||
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||
name: A string, the name of the layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
|
||||
renorm=False,
|
||||
renorm_clipping=None,
|
||||
renorm_momentum=0.99,
|
||||
fused=False,
|
||||
trainable=True,
|
||||
name=None,
|
||||
**kwargs):
|
||||
@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
|
||||
self.beta_regularizer = beta_regularizer
|
||||
self.gamma_regularizer = gamma_regularizer
|
||||
self.renorm = renorm
|
||||
self.fused = fused
|
||||
if self.fused and renorm:
|
||||
raise ValueError(
|
||||
'Batch renorm is currently not supported with fused batch norm.')
|
||||
if renorm:
|
||||
renorm_clipping = renorm_clipping or {}
|
||||
keys = ['rmax', 'rmin', 'dmax']
|
||||
@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
|
||||
if not input_shape.ndims:
|
||||
raise ValueError('Input has undefined rank:', input_shape)
|
||||
ndim = len(input_shape)
|
||||
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
|
||||
# output back to its original shape accordingly.
|
||||
if self.fused and ndim != 4:
|
||||
raise ValueError(
|
||||
'Only 4D inputs are currently supported with fused batch norm. '
|
||||
'Consider reshaping the input to 4D and reshape the output back '
|
||||
'to its original shape. Got input rank: ', ndim)
|
||||
if self.axis < 0:
|
||||
axis = ndim + self.axis
|
||||
else:
|
||||
@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
|
||||
if axis < 0 or axis >= ndim:
|
||||
raise ValueError('Value of `axis` argument ' + str(self.axis) +
|
||||
' is out of range for input with rank ' + str(ndim))
|
||||
|
||||
if self.fused is None:
|
||||
self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
|
||||
|
||||
if self.fused:
|
||||
if axis == 1:
|
||||
self._data_format = 'NCHW'
|
||||
elif axis == 3:
|
||||
self._data_format = 'NHWC'
|
||||
else:
|
||||
raise ValueError(
|
||||
'Only axis 1 and 3 are currently supported dimensions for '
|
||||
'fused batch norm. Got `axis` dimension: ', axis)
|
||||
|
||||
param_dim = input_shape[axis]
|
||||
if not param_dim.value:
|
||||
raise ValueError('Input has undefined `axis` dimension. Input shape: ',
|
||||
@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
|
||||
trainable=True)
|
||||
else:
|
||||
self.beta = None
|
||||
if self.fused:
|
||||
self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
|
||||
if self.scale:
|
||||
self.gamma = self.add_variable(name='gamma',
|
||||
shape=(param_dim,),
|
||||
@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
|
||||
trainable=True)
|
||||
else:
|
||||
self.gamma = None
|
||||
if self.fused:
|
||||
self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
|
||||
|
||||
# Disable variable partitioning when creating the moving mean and variance
|
||||
partitioner = self._scope.partitioner
|
||||
@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
|
||||
self._scope.set_partitioner(partitioner)
|
||||
self.built = True
|
||||
|
||||
def _fused_batch_norm(self, inputs, training):
|
||||
"""Returns the output of fused batch norm."""
|
||||
beta = self.beta if self.center else self._beta_const
|
||||
gamma = self.gamma if self.scale else self._gamma_const
|
||||
|
||||
def _fused_batch_norm_training():
|
||||
return nn.fused_batch_norm(
|
||||
inputs,
|
||||
gamma,
|
||||
beta,
|
||||
epsilon=self.epsilon,
|
||||
data_format=self._data_format)
|
||||
|
||||
def _fused_batch_norm_inference():
|
||||
return nn.fused_batch_norm(
|
||||
inputs,
|
||||
gamma,
|
||||
beta,
|
||||
mean=self.moving_mean,
|
||||
variance=self.moving_variance,
|
||||
epsilon=self.epsilon,
|
||||
is_training=False,
|
||||
data_format=self._data_format)
|
||||
|
||||
output, mean, variance = utils.smart_cond(
|
||||
training, _fused_batch_norm_training, _fused_batch_norm_inference)
|
||||
|
||||
training_value = utils.constant_value(training)
|
||||
if training_value is not False:
|
||||
decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
|
||||
mean_update = moving_averages.assign_moving_average(
|
||||
self.moving_mean, mean, decay, zero_debias=False)
|
||||
variance_update = moving_averages.assign_moving_average(
|
||||
self.moving_variance, variance, decay, zero_debias=False)
|
||||
self.add_update(mean_update, inputs=inputs)
|
||||
self.add_update(variance_update, inputs=inputs)
|
||||
|
||||
return output
|
||||
|
||||
def _renorm_correction_and_moments(self, mean, variance, training):
|
||||
"""Returns the correction and update values for renorm."""
|
||||
stddev = math_ops.sqrt(variance + self.epsilon)
|
||||
@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
|
||||
return (r, d, new_mean, new_variance)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
if self.fused:
|
||||
return self._fused_batch_norm(inputs, training=training)
|
||||
|
||||
# First, compute the axes along which to reduce the mean / variance,
|
||||
# as well as the broadcast shape to be used for all parameters.
|
||||
input_shape = inputs.get_shape()
|
||||
@ -353,7 +427,8 @@ def batch_normalization(inputs,
|
||||
reuse=None,
|
||||
renorm=False,
|
||||
renorm_clipping=None,
|
||||
renorm_momentum=0.99):
|
||||
renorm_momentum=0.99,
|
||||
fused=False):
|
||||
"""Functional interface for the batch normalization layer.
|
||||
|
||||
Reference: http://arxiv.org/abs/1502.03167
|
||||
@ -415,6 +490,8 @@ def batch_normalization(inputs,
|
||||
and should be neither too small (which would add noise) nor too large
|
||||
(which would give stale estimates). Note that `momentum` is still applied
|
||||
to get the means and variances for inference.
|
||||
fused: if `True`, use a faster, fused implementation based on
|
||||
nn.fused_batch_norm. If `None`, use the fused implementation if possible.
|
||||
|
||||
Returns:
|
||||
Output tensor.
|
||||
@ -431,10 +508,11 @@ def batch_normalization(inputs,
|
||||
moving_variance_initializer=moving_variance_initializer,
|
||||
beta_regularizer=beta_regularizer,
|
||||
gamma_regularizer=gamma_regularizer,
|
||||
trainable=trainable,
|
||||
renorm=renorm,
|
||||
renorm_clipping=renorm_clipping,
|
||||
renorm_momentum=renorm_momentum,
|
||||
fused=fused,
|
||||
trainable=trainable,
|
||||
name=name,
|
||||
_reuse=reuse,
|
||||
_scope=name)
|
||||
|
@ -262,6 +262,87 @@ class BNTest(test.TestCase):
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
def test4DInputAxis3Fused(self):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
axis=3, epsilon=epsilon, momentum=0.9, fused=True)
|
||||
inputs = variables.Variable(
|
||||
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
|
||||
training = array_ops.placeholder(dtype='bool')
|
||||
outputs = bn.apply(inputs, training=training)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test training with placeholder learning phase.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
|
||||
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
|
||||
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
|
||||
for _ in range(100):
|
||||
np_output, _, _ = sess.run(
|
||||
[outputs] + bn.updates, feed_dict={training: True})
|
||||
# Verify that the axis is normalized during training.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
# Verify that the statistics are updated during training.
|
||||
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
|
||||
np_inputs = sess.run(inputs)
|
||||
mean = np.mean(np_inputs, axis=(0, 1, 2))
|
||||
std = np.std(np_inputs, axis=(0, 1, 2))
|
||||
variance = np.square(std)
|
||||
self.assertAllClose(mean, moving_mean, atol=1e-2)
|
||||
self.assertAllClose(variance, moving_var, atol=1e-2)
|
||||
|
||||
# Test inference with placeholder learning phase.
|
||||
np_output = sess.run(outputs, feed_dict={training: False})
|
||||
|
||||
# Verify that the axis is normalized during inference.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
def test4DInputAxis1Fused(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
axis=1, epsilon=epsilon, momentum=0.9, fused=True)
|
||||
inputs = variables.Variable(
|
||||
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
|
||||
training = array_ops.placeholder(dtype='bool')
|
||||
outputs = bn.apply(inputs, training=training)
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test training with placeholder learning phase.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
|
||||
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
|
||||
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
|
||||
for _ in range(100):
|
||||
np_output, _, _ = sess.run(
|
||||
[outputs] + bn.updates, feed_dict={training: True})
|
||||
# Verify that the axis is normalized during training.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
# Verify that the statistics are updated during training.
|
||||
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
|
||||
np_inputs = sess.run(inputs)
|
||||
mean = np.mean(np_inputs, axis=(0, 2, 3))
|
||||
std = np.std(np_inputs, axis=(0, 2, 3))
|
||||
variance = np.square(std)
|
||||
self.assertAllClose(mean, moving_mean, atol=1e-2)
|
||||
self.assertAllClose(variance, moving_var, atol=1e-2)
|
||||
|
||||
# Test inference with placeholder learning phase.
|
||||
np_output = sess.run(outputs, feed_dict={training: False})
|
||||
|
||||
# Verify that the axis is normalized during inference.
|
||||
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
|
||||
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
|
||||
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
|
||||
|
||||
def testNegativeAxis(self):
|
||||
epsilon = 1e-3
|
||||
bn = normalization_layers.BatchNormalization(
|
||||
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "batch_normalization"
|
||||
argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
|
||||
argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "conv1d"
|
||||
|
Loading…
Reference in New Issue
Block a user