From eccd162119675d0bf5bc6f8e6a93dcda7ab6db4a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 5 Jul 2017 14:14:54 -0700
Subject: [PATCH] Fix moments calculation to never result in negative variance
 and avoid doing extra work when shift = None. With the current calculation
 shift is ignored.

PiperOrigin-RevId: 161003939
---
 tensorflow/python/ops/nn_impl.py | 34 ++++++++++++--------------------
 tensorflow/python/ops/nn_test.py | 22 ---------------------
 2 files changed, 13 insertions(+), 43 deletions(-)

diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 7bc99ac725a..98ede2031bc 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -580,15 +580,16 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
   return (mean, variance)


-def moments(x, axes, shift=None, name=None, keep_dims=False):
+def moments(x, axes,
+            shift=None,  # pylint: disable=unused-argument
+            name=None, keep_dims=False):
   """Calculate the mean and variance of `x`.

   The mean and variance are calculated by aggregating the contents of `x`
   across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean and
   variance of a vector.

-  Note: for numerical stability, when shift=None, the true mean
-  would be computed and used as shift.
+  Note: shift is currently not used, the true mean is computed and used.

   When using these moments for batch normalization (see
   `tf.nn.batch_normalization`):
@@ -601,35 +602,26 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
   Args:
     x: A `Tensor`.
     axes: Array of ints.  Axes along which to compute mean and variance.
-    shift: A `Tensor` containing the value by which to shift the data for
-      numerical stability, or `None` in which case the true mean of the data is
-      used as shift. A shift close to the true mean provides the most
-      numerically stable results.
+    shift: Not used in the current implementation
     name: Name used to scope the operations that compute the moments.
     keep_dims: produce moments with the same dimensionality as the input.

   Returns:
     Two `Tensor` objects: `mean` and `variance`.
   """
-  with ops.name_scope(name, "moments", [x, axes, shift]):
+  with ops.name_scope(name, "moments", [x, axes]):
     # The dynamic range of fp16 is too limited to support the collection of
     # sufficient statistics. As a workaround we simply perform the operations
     # on 32-bit floats before converting the mean and variance back to fp16
     y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
-    if shift is None:
-      # Compute true mean while keeping the dims for proper broadcasting.
-      shift = array_ops.stop_gradient(
-          math_ops.reduce_mean(y, axes, keep_dims=True))
-    else:
-      shift = math_ops.cast(shift, y.dtype)
-    shifted_mean = math_ops.reduce_mean(
-        math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean")
-    variance = math_ops.subtract(
-        math_ops.reduce_mean(
-            math_ops.squared_difference(y, shift), axes, keep_dims=True),
-        math_ops.square(shifted_mean),
+    # Compute true mean while keeping the dims for proper broadcasting.
+    mean = math_ops.reduce_mean(y, axes, keep_dims=True, name="mean")
+    # sample variance, not unbiased variance
+    variance = math_ops.reduce_mean(
+        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
+        axes,
+        keep_dims=True,
         name="variance")
-    mean = math_ops.add(shifted_mean, shift, name="mean")
     if not keep_dims:
       mean = array_ops.squeeze(mean, axes)
       variance = array_ops.squeeze(variance, axes)
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index f3592941d1e..87f6f92a8a8 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -877,28 +877,6 @@ class MomentsTest(test_lib.TestCase):
   def testOutput4DInput123(self):
     self.doOutputTest((10, 10, 10, 30), (1, 2, 3))

-  def testUnstableOutputShiftNone(self):
-    input_shape = (10, 300)
-    moments_axes = (0, 1)
-    mu, sigma = 1e3, 0.1
-    tol = 1e-3
-    input_values = np.random.rand(*input_shape) * sigma + mu
-    expected_mean = np.mean(input_values, axis=moments_axes)
-    expected_var = np.var(input_values, axis=moments_axes)
-
-    with self.test_session() as sess:
-      inputs = constant_op.constant(
-          input_values, shape=input_shape, dtype=dtypes.float32)
-      mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)
-
-      [mean, variance] = sess.run([mean, variance])
-      # Make sure that there are no NaNs
-      self.assertFalse(np.isnan(mean).any())
-      self.assertFalse(np.isnan(variance).any())
-      self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
-      # The variance is unstable
-      self.assertGreater(np.abs(variance - expected_var), 0.1)
-

 if __name__ == "__main__":
   test_lib.main()
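
A minimal NumPy sketch of the instability behind this change (illustrative only; not part of the patch and not the TensorFlow implementation). With data whose mean is large relative to its spread, as in the deleted testUnstableOutputShiftNone, the shifted difference-of-squares formula loses essentially all float32 precision and can even come out negative, whereas averaging squared deviations from the true mean, as the new moments code does, stays accurate and is never negative.

# Illustration only: NumPy stand-in for the two variance formulas. The
# constants (mu = 1e3, sigma = 0.1, shape (10, 300)) mirror the removed test.
import numpy as np

np.random.seed(0)
x = (np.random.rand(10, 300) * 0.1 + 1e3).astype(np.float32)

# Reference computed in float64: roughly 0.1**2 / 12 ~= 8.3e-4.
reference_var = np.var(x.astype(np.float64))

# Old-style formula with a user-supplied shift far from the mean (shift=0.0):
# E[(x - s)**2] - (E[x - s])**2. Both terms are ~1e6 while their true
# difference (~8e-4) lies far below float32's ~7 significant digits at that
# magnitude, so the subtraction returns mostly rounding noise and can even
# be negative.
shift = np.float32(0.0)
old_var = (np.mean((x - shift)**2, dtype=np.float32)
           - np.mean(x - shift, dtype=np.float32)**2)

# New-style formula: subtract the true mean first, then average the squares.
# All intermediates are on the order of sigma**2, so float32 is accurate,
# and a mean of squares cannot be negative.
mean = np.mean(x, dtype=np.float32)
new_var = np.mean((x - mean)**2, dtype=np.float32)

print("float64 reference:           ", reference_var)
print("difference of squares (old): ", old_var)
print("squared deviations (new):    ", new_var)

This is the same scenario the removed testUnstableOutputShiftNone exercised: it passed shift=0.0 and asserted that the resulting variance was off by more than 0.1. With shift now ignored and the true mean always used, that scenario no longer applies, which is why the test is deleted.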