Now that moments are numerically stable by default, there is no need to use moving_mean as a shift.

Change: 144251474
Sergio Guadarrama 2017-01-11 14:37:45 -08:00 committed by TensorFlower Gardener
parent 6426c4efd8
commit 37af1b8790
4 changed files with 109 additions and 110 deletions
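
As a hedged illustration (not part of the commit), a minimal TF 1.x-style sketch of the simplification applied in the hunks below; the input shapes and offset are hypothetical:

import tensorflow as tf

# Hypothetical input with a large offset, the case the old shift was guarding against.
inputs = tf.random_normal([8, 32, 32, 64]) + 100.0
moments_axes = [0, 1, 2]

# Before this commit: a copy of moving_mean was passed as `shift` so that
# nn.moments stayed numerically stable, e.g.
#   shift = tf.add(moving_mean, 0)
#   mean, variance = tf.nn.moments(inputs, moments_axes, shift=shift)

# After: nn.moments is numerically stable by default, so the shift is dropped.
mean, variance = tf.nn.moments(inputs, moments_axes)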


@@ -173,9 +173,10 @@ def _fused_batch_norm(
`data_format` is `NHWC` and the second dimension if `data_format` is
`NCHW`.
decay: decay for the moving average. Reasonable values for `decay` are close
to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower
`decay` value (recommend trying `decay`=0.9) if model experiences reasonably
good training performance but poor validation and/or test performance.
to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
Lower `decay` value (recommend trying `decay`=0.9) if model experiences
reasonably good training performance but poor validation and/or test
performance.
center: If True, subtract `beta`. If False, `beta` is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can be
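
A usage sketch (not part of the diff; parameter values hypothetical) of the guidance above on lowering `decay` when training looks good but validation or test performance lags:

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training = tf.placeholder(tf.bool)
# Sketch: lower `decay` from the default 0.999 to 0.9 so the moving averages
# track the batch statistics more quickly.
net = tf.contrib.layers.batch_norm(
    inputs, decay=0.9, center=True, scale=True, is_training=is_training)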
@@ -630,16 +631,12 @@ def batch_norm(
if need_moments:
# Calculate the moments based on the individual batch.
if batch_weights is None:
# Use a copy of moving_mean as a shift to compute more reliable moments.
shift = math_ops.add(moving_mean, 0)
if data_format == DATA_FORMAT_NCHW:
shift = array_ops.reshape(shift, params_shape_broadcast)
mean, variance = nn.moments(inputs, moments_axes, shift=shift,
keep_dims=True)
mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
mean, variance = nn.moments(inputs, moments_axes, shift=shift)
mean, variance = nn.moments(inputs, moments_axes)
else:
if data_format == DATA_FORMAT_NCHW:
mean, variance = nn.weighted_moments(inputs, moments_axes,
@@ -1383,7 +1380,7 @@ def fully_connected(inputs,
Raises:
ValueError: if x has rank less than 2 or if its last dimension is not set.
"""
if not (isinstance(num_outputs, six.integer_types)):
if not isinstance(num_outputs, six.integer_types):
raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
layer_variable_getter = _build_variable_getter({'bias': 'biases'})


@@ -2356,7 +2356,7 @@ class BatchNormTest(test.TestCase):
else:
image_shape = (batch_size, channels, height, width)
axis = (0, 2, 3)
image_values = np.random.rand(*image_shape) + 2
image_values = np.random.rand(*image_shape) + 256
expected_mean = np.mean(image_values, axis=axis)
expected_var = np.var(image_values, axis=axis)
if fused:
@@ -2393,9 +2393,9 @@ class BatchNormTest(test.TestCase):
# The outputs should be close to 0.0 mean and 1.0 variance
self.assertAllClose(
np.mean(
np_output, axis=axis), [0] * channels, rtol=0.1, atol=0.1)
np_output, axis=axis), [0] * channels, rtol=0.001, atol=0.001)
self.assertAllClose(
np.var(np_output, axis=axis), [1] * channels, rtol=0.1, atol=0.1)
np.var(np_output, axis=axis), [1] * channels, rtol=0.01, atol=0.01)
# The gradients should change slowly while updating moving_mean.
max_diff = np.max(np.abs(images_gradients_value - new_images_gradients))
self.assertGreaterEqual(max_diff, 0.0)
@@ -2558,25 +2558,29 @@ class LayerNormTest(test.TestCase):
# output_train and output_eval should be the same.
self.assertAllClose(sess.run([output_train]), sess.run([output_eval]))
def doOutputTest(self, input_shape):
with self.test_session() as sess:
input_values = np.random.rand(*input_shape)
inputs = constant_op.constant(
input_values, shape=input_shape, dtype=dtypes.float32)
output_op = _layers.layer_norm(inputs, scope='LN')
# Initialize all variables
sess.run(variables_lib.global_variables_initializer())
# The mean and variance of the output should be close to 0 and 1
# respectively.
moments_axis = tuple([i for i in range(1, len(input_shape))])
outputs = sess.run(output_op)
expected_mean = np.zeros(input_shape[0])
expected_var = np.ones(input_shape[0])
mean = np.mean(outputs, axis=moments_axis)
var = np.var(outputs, axis=moments_axis)
tol = 1e-5
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
self.assertAllClose(var, expected_var, rtol=tol, atol=tol)
def doOutputTest(self, input_shape, tol=1e-3):
for mu in [0.0, 1e2]:
for sigma in [1.0, 0.1]:
input_values = np.random.rand(*input_shape) * sigma + mu
expected_mean = np.zeros(input_shape[0])
expected_var = np.ones(input_shape[0])
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
inputs = constant_op.constant(input_values, shape=input_shape,
dtype=dtypes.float32)
output_op = _layers.layer_norm(inputs, scope='LN')
# Initialize all variables
sess.run(variables_lib.global_variables_initializer())
# The mean and variance of the output should be close to 0 and 1
# respectively.
moments_axis = tuple([i for i in range(1, len(input_shape))])
outputs = sess.run(output_op)
# Make sure that there are no NaNs
self.assertFalse(np.isnan(outputs).any())
mean = np.mean(outputs, axis=moments_axis)
var = np.var(outputs, axis=moments_axis)
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
self.assertAllClose(var, expected_var, rtol=tol, atol=tol)
def testOutput2DInput(self):
self.doOutputTest((10, 300))
@@ -2584,6 +2588,12 @@ class LayerNormTest(test.TestCase):
def testOutput4DInput(self):
self.doOutputTest((100, 10, 10, 3))
def testOutputSmallInput(self):
self.doOutputTest((10, 10, 10, 30))
def testOutputBigInput(self):
self.doOutputTest((1, 100, 100, 1))
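
The new LayerNormTest cases above sweep a large offset (mu = 1e2) and assert that the outputs contain no NaNs; the failure mode they guard against is catastrophic cancellation in a naive one-pass variance. A small NumPy sketch (illustration only, not part of the commit):

import numpy as np

# Values around 100 with a small spread, in float32.
x = (np.random.rand(10000) * 0.1 + 1e2).astype(np.float32)

# Naive one-pass formula E[x^2] - E[x]^2 cancels badly when the mean is
# large relative to the spread; the result can be wildly off or negative.
naive_var = np.mean(x * x) - np.mean(x) ** 2

# Two-pass (or mean-shifted) formula stays accurate.
stable_var = np.mean((x - np.mean(x)) ** 2)

print(naive_var, stable_var)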
class MaxPool2DTest(test.TestCase):


@@ -178,16 +178,13 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
broadcast_gamma = None
if training_value is not False:
# Use a copy of moving_mean as a shift to compute more reliable moments.
shift = math_ops.add(self.moving_mean, 0)
if needs_broadcasting:
shift = array_ops.reshape(shift, broadcast_shape)
broadcast_mean, broadcast_variance = nn.moments(
inputs, reduction_axes, shift=shift, keep_dims=True)
inputs, reduction_axes, keep_dims=True)
mean = array_ops.reshape(broadcast_mean, [-1])
variance = array_ops.reshape(broadcast_variance, [-1])
else:
mean, variance = nn.moments(inputs, reduction_axes, shift=shift)
mean, variance = nn.moments(inputs, reduction_axes)
# Prepare updates if necessary.
if not self.updates:
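
For context, a condensed sketch of how this layer is driven in the tests below (shapes, offset, and iteration count mirror the tests; the session setup is illustrative only):

import numpy as np
import tensorflow as tf
from tensorflow.python.layers import normalization as normalization_layers

bn = normalization_layers.BatchNormalization(axis=-1, epsilon=1e-3, momentum=0.9)
inputs = tf.Variable(np.random.random((5, 4, 3, 6)) + 100, dtype=tf.float32)
training = tf.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(100):
    # bn.updates holds the moving_mean / moving_variance update ops.
    sess.run([outputs] + bn.updates, feed_dict={training: True})
  # Inference uses the accumulated moving statistics.
  sess.run(outputs, feed_dict={training: False})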


@@ -63,16 +63,25 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=1, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3)), dtype=dtypes.float32)
np.random.random((5, 4, 3)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1))
np_beta = np.reshape(np_beta, (1, 4, 1))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -83,14 +92,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1))
np_beta = np.reshape(np_beta, (1, 4, 1))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -104,16 +105,23 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=2, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3)), dtype=dtypes.float32)
np.random.random((5, 4, 3)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3))
np_beta = np.reshape(np_beta, (1, 1, 3))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -124,14 +132,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3))
np_beta = np.reshape(np_beta, (1, 1, 3))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -145,16 +145,23 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=1, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -165,14 +172,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -186,16 +185,23 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=2, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3, 1))
np_beta = np.reshape(np_beta, (1, 1, 3, 1))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -206,14 +212,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3, 1))
np_beta = np.reshape(np_beta, (1, 1, 3, 1))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -227,16 +225,23 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=3, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -247,14 +252,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -268,17 +265,25 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=-1, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
np_inputs = sess.run(inputs)
@@ -288,14 +293,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -309,15 +306,22 @@ class BNTest(test.TestCase):
bn = normalization_layers.BatchNormalization(
axis=-1, epsilon=epsilon, momentum=0.9)
inputs = variables.Variable(
np.random.random((5, 4, 3, 6)), dtype=dtypes.float32)
np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32)
outputs_training = bn.apply(inputs, training=True)
outputs_infer = bn.apply(inputs, training=False)
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
np_output, _, _ = sess.run([outputs_training] + bn.updates)
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=2)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
@@ -328,14 +332,6 @@ class BNTest(test.TestCase):
self.assertAllClose(mean, moving_mean, atol=1e-2)
self.assertAllClose(variance, moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs_infer)
@@ -367,9 +363,16 @@ class BNTest(test.TestCase):
with self.test_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([gamma, beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
np_output, _, _ = sess.run([outputs] + updates,
feed_dict={training: True})
# Verify that the axis is normalized during training.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance])
@@ -380,14 +383,6 @@ class BNTest(test.TestCase):
self.assertAllClose(np_mean, np_moving_mean, atol=1e-2)
self.assertAllClose(np_variance, np_moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
np_gamma, np_beta = sess.run([gamma, beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
np_output = sess.run(outputs, feed_dict={training: False})
@@ -448,7 +443,7 @@ class BNTest(test.TestCase):
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=2)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Test inference with placeholder learning phase.
@@ -456,7 +451,7 @@ class BNTest(test.TestCase):
# Verify that the axis is normalized during inference.
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1)
self.assertAlmostEqual(np.mean(normed_np_output), 0., places=2)
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
def testNoCenter(self):