Fix moments calculation to never result in negative variance and avoid

doing extra work when shift = None. With the new calculation, the `shift`
argument is ignored.

PiperOrigin-RevId: 161003939
This commit is contained in:
A. Unique TensorFlower 2017-07-05 14:14:54 -07:00 committed by TensorFlower Gardener
parent 70804d820b
commit eccd162119
2 changed files with 13 additions and 43 deletions

View File

@ -580,15 +580,16 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
return (mean, variance) return (mean, variance)
def moments(x, axes, shift=None, name=None, keep_dims=False): def moments(x, axes,
shift=None, # pylint: disable=unused-argument
name=None, keep_dims=False):
"""Calculate the mean and variance of `x`. """Calculate the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `x` The mean and variance are calculated by aggregating the contents of `x`
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
and variance of a vector. and variance of a vector.
Note: for numerical stability, when shift=None, the true mean Note: shift is currently not used, the true mean is computed and used.
would be computed and used as shift.
When using these moments for batch normalization (see When using these moments for batch normalization (see
`tf.nn.batch_normalization`): `tf.nn.batch_normalization`):
@ -601,35 +602,26 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
x: A `Tensor`. x: A `Tensor`.
axes: Array of ints. Axes along which to compute mean and axes: Array of ints. Axes along which to compute mean and
variance. variance.
shift: A `Tensor` containing the value by which to shift the data for shift: Not used in the current implementation
numerical stability, or `None` in which case the true mean of the data is
used as shift. A shift close to the true mean provides the most
numerically stable results.
name: Name used to scope the operations that compute the moments. name: Name used to scope the operations that compute the moments.
keep_dims: produce moments with the same dimensionality as the input. keep_dims: produce moments with the same dimensionality as the input.
Returns: Returns:
Two `Tensor` objects: `mean` and `variance`. Two `Tensor` objects: `mean` and `variance`.
""" """
with ops.name_scope(name, "moments", [x, axes, shift]): with ops.name_scope(name, "moments", [x, axes]):
# The dynamic range of fp16 is too limited to support the collection of # The dynamic range of fp16 is too limited to support the collection of
# sufficient statistics. As a workaround we simply perform the operations # sufficient statistics. As a workaround we simply perform the operations
# on 32-bit floats before converting the mean and variance back to fp16 # on 32-bit floats before converting the mean and variance back to fp16
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
if shift is None:
# Compute true mean while keeping the dims for proper broadcasting. # Compute true mean while keeping the dims for proper broadcasting.
shift = array_ops.stop_gradient( mean = math_ops.reduce_mean(y, axes, keep_dims=True, name="mean")
math_ops.reduce_mean(y, axes, keep_dims=True)) # sample variance, not unbiased variance
else: variance = math_ops.reduce_mean(
shift = math_ops.cast(shift, y.dtype) math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
shifted_mean = math_ops.reduce_mean( axes,
math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean") keep_dims=True,
variance = math_ops.subtract(
math_ops.reduce_mean(
math_ops.squared_difference(y, shift), axes, keep_dims=True),
math_ops.square(shifted_mean),
name="variance") name="variance")
mean = math_ops.add(shifted_mean, shift, name="mean")
if not keep_dims: if not keep_dims:
mean = array_ops.squeeze(mean, axes) mean = array_ops.squeeze(mean, axes)
variance = array_ops.squeeze(variance, axes) variance = array_ops.squeeze(variance, axes)

View File

@ -877,28 +877,6 @@ class MomentsTest(test_lib.TestCase):
def testOutput4DInput123(self): def testOutput4DInput123(self):
self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
def testUnstableOutputShiftNone(self):
input_shape = (10, 300)
moments_axes = (0, 1)
mu, sigma = 1e3, 0.1
tol = 1e-3
input_values = np.random.rand(*input_shape) * sigma + mu
expected_mean = np.mean(input_values, axis=moments_axes)
expected_var = np.var(input_values, axis=moments_axes)
with self.test_session() as sess:
inputs = constant_op.constant(
input_values, shape=input_shape, dtype=dtypes.float32)
mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)
[mean, variance] = sess.run([mean, variance])
# Make sure that there are no NaNs
self.assertFalse(np.isnan(mean).any())
self.assertFalse(np.isnan(variance).any())
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
# The variance is unstable
self.assertGreater(np.abs(variance - expected_var), 0.1)
if __name__ == "__main__": if __name__ == "__main__":
test_lib.main() test_lib.main()