Add fused batch norm to tf.layers.
PiperOrigin-RevId: 157893874
commit 675d36be0d
parent f37d0ea47b
@@ -3555,13 +3555,11 @@ py_test(
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "layers_normalization_test",
     size = "small",
     srcs = ["layers/normalization_test.py"],
-    main = "layers/normalization_test.py",
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":array_ops",
         ":client_testlib",
         ":framework_for_generated_wrappers",
@@ -3571,6 +3569,7 @@ py_test(
         ":variables",
         "//third_party/py/numpy",
     ],
+    main = "layers/normalization_test.py",
 )
 
 # -----------------------------------------------------------------------------
@@ -66,9 +66,6 @@ class BatchNormalization(base.Layer):
     moving_variance_initializer: Initializer for the moving variance.
     beta_regularizer: Optional regularizer for the beta weight.
     gamma_regularizer: Optional regularizer for the gamma weight.
-    trainable: Boolean, if `True` also add variables to the graph collection
-      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
-    name: A string, the name of the layer.
     renorm: Whether to use Batch Renormalization
       (https://arxiv.org/abs/1702.03275). This adds extra variables during
       training. The inference is the same for either value of this parameter.
@@ -82,6 +79,11 @@ class BatchNormalization(base.Layer):
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
+    trainable: Boolean, if `True` also add variables to the graph collection
+      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+    name: A string, the name of the layer.
   """
 
   def __init__(self,
@@ -99,6 +101,7 @@ class BatchNormalization(base.Layer):
                renorm=False,
                renorm_clipping=None,
                renorm_momentum=0.99,
+               fused=False,
                trainable=True,
                name=None,
                **kwargs):
@@ -116,6 +119,10 @@ class BatchNormalization(base.Layer):
     self.beta_regularizer = beta_regularizer
     self.gamma_regularizer = gamma_regularizer
     self.renorm = renorm
+    self.fused = fused
+    if self.fused and renorm:
+      raise ValueError(
+          'Batch renorm is currently not supported with fused batch norm.')
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
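
Usage sketch (illustrative only, not part of this diff): constructing the layer with the new `fused` flag on a 4D NHWC input, mirroring how the tests below instantiate it. The placeholder shape is made up for the example.

# Illustrative sketch, not part of this change: build the layer with fused=True.
# fused=True together with renorm=True raises ValueError, as enforced by the
# __init__ check above.
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.ops import array_ops

inputs = array_ops.placeholder(dtypes.float32, shape=(None, 32, 32, 64))  # NHWC
training = array_ops.placeholder(dtype='bool')

bn = normalization_layers.BatchNormalization(
    axis=3, epsilon=1e-3, momentum=0.9, fused=True)
outputs = bn.apply(inputs, training=training)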
@@ -130,6 +137,13 @@ class BatchNormalization(base.Layer):
     if not input_shape.ndims:
       raise ValueError('Input has undefined rank:', input_shape)
     ndim = len(input_shape)
+    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+    # output back to its original shape accordingly.
+    if self.fused and ndim != 4:
+      raise ValueError(
+          'Only 4D inputs are currently supported with fused batch norm. '
+          'Consider reshaping the input to 4D and reshape the output back '
+          'to its original shape. Got input rank: ', ndim)
     if self.axis < 0:
       axis = ndim + self.axis
     else:
@@ -137,6 +151,20 @@ class BatchNormalization(base.Layer):
     if axis < 0 or axis >= ndim:
       raise ValueError('Value of `axis` argument ' + str(self.axis) +
                        ' is out of range for input with rank ' + str(ndim))
+
+    if self.fused is None:
+      self.fused = not self.renorm and ndim == 4 and axis in [1, 3]
+
+    if self.fused:
+      if axis == 1:
+        self._data_format = 'NCHW'
+      elif axis == 3:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError(
+            'Only axis 1 and 3 are currently supported dimensions for '
+            'fused batch norm. Got `axis` dimension: ', axis)
+
     param_dim = input_shape[axis]
     if not param_dim.value:
       raise ValueError('Input has undefined `axis` dimension. Input shape: ',
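
The auto-selection above reduces to a small predicate. A plain-Python restatement for reference (illustrative; `_select_fused` is a hypothetical helper, not part of the layer's API):

# Illustrative restatement of the fused=None selection logic above.
def _select_fused(fused, renorm, ndim, axis):
  if fused is None:
    fused = not renorm and ndim == 4 and axis in (1, 3)
  data_format = {1: 'NCHW', 3: 'NHWC'}.get(axis) if fused else None
  return fused, data_format

assert _select_fused(None, False, 4, 3) == (True, 'NHWC')   # NHWC input
assert _select_fused(None, False, 4, 1) == (True, 'NCHW')   # NCHW input
assert _select_fused(None, True, 4, 3) == (False, None)     # renorm disables it
assert _select_fused(None, False, 3, 2) == (False, None)    # non-4D stays unfused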
@@ -152,6 +180,8 @@ class BatchNormalization(base.Layer):
                                     trainable=True)
     else:
       self.beta = None
+      if self.fused:
+        self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
     if self.scale:
       self.gamma = self.add_variable(name='gamma',
                                      shape=(param_dim,),
@@ -160,6 +190,8 @@ class BatchNormalization(base.Layer):
                                      trainable=True)
     else:
       self.gamma = None
+      if self.fused:
+        self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
 
     # Disable variable partitioning when creating the moving mean and variance
     partitioner = self._scope.partitioner
@@ -205,6 +237,45 @@ class BatchNormalization(base.Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
+  def _fused_batch_norm(self, inputs, training):
+    """Returns the output of fused batch norm."""
+    beta = self.beta if self.center else self._beta_const
+    gamma = self.gamma if self.scale else self._gamma_const
+
+    def _fused_batch_norm_training():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          epsilon=self.epsilon,
+          data_format=self._data_format)
+
+    def _fused_batch_norm_inference():
+      return nn.fused_batch_norm(
+          inputs,
+          gamma,
+          beta,
+          mean=self.moving_mean,
+          variance=self.moving_variance,
+          epsilon=self.epsilon,
+          is_training=False,
+          data_format=self._data_format)
+
+    output, mean, variance = utils.smart_cond(
+        training, _fused_batch_norm_training, _fused_batch_norm_inference)
+
+    training_value = utils.constant_value(training)
+    if training_value is not False:
+      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
+      mean_update = moving_averages.assign_moving_average(
+          self.moving_mean, mean, decay, zero_debias=False)
+      variance_update = moving_averages.assign_moving_average(
+          self.moving_variance, variance, decay, zero_debias=False)
+      self.add_update(mean_update, inputs=inputs)
+      self.add_update(variance_update, inputs=inputs)
+
+    return output
+
   def _renorm_correction_and_moments(self, mean, variance, training):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
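
For reference, `assign_moving_average(..., zero_debias=False)` applies moving <- decay * moving + (1 - decay) * value, and `_smart_select` pins decay to 1.0 whenever `training` evaluates to False, so inference leaves the statistics untouched. A small numpy sketch of that update rule (illustrative only, not TF code):

# Numpy sketch of the update performed by
# moving_averages.assign_moving_average(var, value, decay, zero_debias=False):
#   var <- decay * var + (1 - decay) * value
import numpy as np

def update_moving(moving, batch_value, decay):
  return decay * moving + (1.0 - decay) * batch_value

moving_mean = np.zeros(3)
batch_mean = np.array([1.0, 2.0, 3.0])
print(update_moving(moving_mean, batch_mean, decay=0.99))  # [0.01 0.02 0.03]
print(update_moving(moving_mean, batch_mean, decay=1.0))   # [0. 0. 0.], frozen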
@@ -265,6 +336,9 @@ class BatchNormalization(base.Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=False):
+    if self.fused:
+      return self._fused_batch_norm(inputs, training=training)
+
     # First, compute the axes along which to reduce the mean / variance,
     # as well as the broadcast shape to be used for all parameters.
     input_shape = inputs.get_shape()
@@ -353,7 +427,8 @@ def batch_normalization(inputs,
                         reuse=None,
                         renorm=False,
                         renorm_clipping=None,
-                        renorm_momentum=0.99):
+                        renorm_momentum=0.99,
+                        fused=False):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@@ -415,6 +490,8 @@ def batch_normalization(inputs,
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `momentum` is still applied
       to get the means and variances for inference.
+    fused: if `True`, use a faster, fused implementation based on
+      nn.fused_batch_norm. If `None`, use the fused implementation if possible.
 
   Returns:
     Output tensor.
@@ -431,10 +508,11 @@ def batch_normalization(inputs,
       moving_variance_initializer=moving_variance_initializer,
       beta_regularizer=beta_regularizer,
       gamma_regularizer=gamma_regularizer,
-      trainable=trainable,
       renorm=renorm,
       renorm_clipping=renorm_clipping,
      renorm_momentum=renorm_momentum,
+      fused=fused,
+      trainable=trainable,
       name=name,
       _reuse=reuse,
       _scope=name)
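
Usage sketch of the functional wrapper with the new argument (illustrative only, not part of this change; the UPDATE_OPS control dependency shown is the usual tf.layers.batch_normalization training pattern, which this diff does not alter):

# Illustrative sketch: request the fused kernel via the functional interface
# and run the moving-statistic updates alongside the train op.
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 32, 32, 64])
is_training = tf.placeholder(tf.bool)

y = tf.layers.batch_normalization(
    x, axis=3, momentum=0.99, training=is_training, fused=True)

loss = tf.reduce_mean(tf.square(y))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)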
@@ -262,6 +262,87 @@ class BNTest(test.TestCase):
       self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
       self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
 
+  def test4DInputAxis3Fused(self):
+    epsilon = 1e-3
+    bn = normalization_layers.BatchNormalization(
+        axis=3, epsilon=epsilon, momentum=0.9, fused=True)
+    inputs = variables.Variable(
+        np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+    training = array_ops.placeholder(dtype='bool')
+    outputs = bn.apply(inputs, training=training)
+
+    with self.test_session() as sess:
+      # Test training with placeholder learning phase.
+      sess.run(variables.global_variables_initializer())
+      np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+      np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
+      np_beta = np.reshape(np_beta, (1, 1, 1, 6))
+      for _ in range(100):
+        np_output, _, _ = sess.run(
+            [outputs] + bn.updates, feed_dict={training: True})
+        # Verify that the axis is normalized during training.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+      # Verify that the statistics are updated during training.
+      moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+      np_inputs = sess.run(inputs)
+      mean = np.mean(np_inputs, axis=(0, 1, 2))
+      std = np.std(np_inputs, axis=(0, 1, 2))
+      variance = np.square(std)
+      self.assertAllClose(mean, moving_mean, atol=1e-2)
+      self.assertAllClose(variance, moving_var, atol=1e-2)
+
+      # Test inference with placeholder learning phase.
+      np_output = sess.run(outputs, feed_dict={training: False})
+
+      # Verify that the axis is normalized during inference.
+      normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+      self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+      self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+  def test4DInputAxis1Fused(self):
+    if test.is_gpu_available(cuda_only=True):
+      epsilon = 1e-3
+      bn = normalization_layers.BatchNormalization(
+          axis=1, epsilon=epsilon, momentum=0.9, fused=True)
+      inputs = variables.Variable(
+          np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
+      training = array_ops.placeholder(dtype='bool')
+      outputs = bn.apply(inputs, training=training)
+
+      with self.test_session() as sess:
+        # Test training with placeholder learning phase.
+        sess.run(variables.global_variables_initializer())
+        np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+        np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
+        np_beta = np.reshape(np_beta, (1, 4, 1, 1))
+        for _ in range(100):
+          np_output, _, _ = sess.run(
+              [outputs] + bn.updates, feed_dict={training: True})
+          # Verify that the axis is normalized during training.
+          normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+          self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+          self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
+        # Verify that the statistics are updated during training.
+        moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+        np_inputs = sess.run(inputs)
+        mean = np.mean(np_inputs, axis=(0, 2, 3))
+        std = np.std(np_inputs, axis=(0, 2, 3))
+        variance = np.square(std)
+        self.assertAllClose(mean, moving_mean, atol=1e-2)
+        self.assertAllClose(variance, moving_var, atol=1e-2)
+
+        # Test inference with placeholder learning phase.
+        np_output = sess.run(outputs, feed_dict={training: False})
+
+        # Verify that the axis is normalized during inference.
+        normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
+        self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
+        self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
+
   def testNegativeAxis(self):
     epsilon = 1e-3
     bn = normalization_layers.BatchNormalization(
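
A quick check of why 100 training steps suffice in the tests above: the inputs are fixed, so the moving statistics approach the batch statistics geometrically, and with momentum 0.9 the initial value retains a weight of 0.9**100 ≈ 2.7e-5 after 100 updates. Even with moving_mean starting at 0 and a batch mean near 100, the residual error is on the order of 1e-3, well inside the assertAllClose tolerance of 1e-2 (illustrative arithmetic only):

# Illustrative arithmetic for the convergence of the moving statistics above.
momentum, steps = 0.9, 100
residual_weight = momentum ** steps
print(residual_weight)                  # ~2.66e-05
print(residual_weight * 100.5)          # worst-case mean error, ~2.7e-03
assert residual_weight * 100.5 < 1e-2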
@@ -14,7 +14,7 @@ tf_module {
   }
   member_method {
     name: "batch_normalization"
-    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\'], "
+    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], "
   }
   member_method {
     name: "conv1d"