diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 60499c36ec4..d8f95d98326 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -580,6 +580,9 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
   across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
   and variance of a vector.
 
+  Note: for numerical stability, when shift=None, the true mean
+  would be computed and used as shift.
+
   When using these moments for batch normalization (see
   `tf.nn.batch_normalization`):
 
@@ -592,8 +595,9 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
     axes: Array of ints.  Axes along which to compute mean and
       variance.
     shift: A `Tensor` containing the value by which to shift the data for
-      numerical stability, or `None` if no shift is to be performed. A shift
-      close to the true mean provides the most numerically stable results.
+      numerical stability, or `None` in which case the true mean of the data is
+      used as shift. A shift close to the true mean provides the most
+      numerically stable results.
     name: Name used to scope the operations that compute the moments.
     keep_dims: produce moments with the same dimensionality as the input.
 
@@ -605,10 +609,17 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
     # sufficient statistics. As a workaround we simply perform the operations
     # on 32-bit floats before converting the mean and variance back to fp16
     y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
-    shift = math_ops.cast(shift, dtypes.float32) if (
-        shift is not None and x.dtype == dtypes.float16) else shift
+    if shift is None:
+      # Compute true mean while keeping the dims for proper broadcasting.
+      shift = array_ops.stop_gradient(
+          math_ops.reduce_mean(y, axes, keep_dims=True))
+    else:
+      shift = math_ops.cast(shift, y.dtype)
     counts, m_ss, v_ss, shift = sufficient_statistics(
         y, axes, shift=shift, keep_dims=keep_dims, name=name)
+    # Reshape shift as needed.
+    shift = array_ops.reshape(shift, array_ops.shape(m_ss))
+    shift.set_shape(m_ss.get_shape())
     with ops.control_dependencies([counts, m_ss, v_ss]):
       mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
     if x.dtype == dtypes.float16:
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 705ca574d8e..25e7a4f45fe 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -25,6 +25,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_impl
@@ -791,5 +792,78 @@ class CReluTest(test_lib.TestCase):
       self.assertAllClose(y, z, 1e-4)
 
 
+class MomentsTest(test_lib.TestCase):
+
+  def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
+    for mu in [0.0, 1.0, 1e3]:
+      for sigma in [1.0, 0.1]:
+        for keep_dims in [True, False]:
+          input_values = np.random.rand(*input_shape) * sigma + mu
+          expected_mean = np.mean(input_values, axis=moments_axes,
+                                  keepdims=keep_dims)
+          expected_var = np.var(input_values, axis=moments_axes,
+                                keepdims=keep_dims)
+          with ops.Graph().as_default() as g:
+            with self.test_session(graph=g) as sess:
+              inputs = constant_op.constant(input_values,
+                                            shape=input_shape,
+                                            dtype=dtypes.float32)
+              mean, variance = nn_impl.moments(inputs,
+                                               moments_axes,
+                                               keep_dims=keep_dims)
+
+              [mean, variance] = sess.run([mean, variance])
+              # Make sure that there are no NaNs
+              self.assertFalse(np.isnan(mean).any())
+              self.assertFalse(np.isnan(variance).any())
+              self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
+              self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
+
+  def testOutput2DInput0(self):
+    self.doOutputTest((10, 300), (0,))
+
+  def testOutput2DInput1(self):
+    self.doOutputTest((10, 300), (1,))
+
+  def testOutput2DInput01(self):
+    self.doOutputTest((10, 300), (0, 1))
+
+  def testOutput4DInput0(self):
+    self.doOutputTest((10, 10, 10, 30), (0,))
+
+  def testOutput4DInput1(self):
+    self.doOutputTest((10, 10, 10, 30), (1,))
+
+  def testOutput4DInput3(self):
+    self.doOutputTest((10, 10, 10, 30), (3,))
+
+  def testOutput4DInput012(self):
+    self.doOutputTest((10, 10, 10, 30), (0, 1, 2))
+
+  def testOutput4DInput123(self):
+    self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
+
+  def testUnstableOutputShiftNone(self):
+    input_shape = (10, 300)
+    moments_axes = (0, 1)
+    mu, sigma = 1e3, 0.1
+    tol = 1e-3
+    input_values = np.random.rand(*input_shape) * sigma + mu
+    expected_mean = np.mean(input_values, axis=moments_axes)
+    expected_var = np.var(input_values, axis=moments_axes)
+
+    with self.test_session() as sess:
+      inputs = constant_op.constant(input_values, shape=input_shape,
+                                    dtype=dtypes.float32)
+      mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)
+
+      [mean, variance] = sess.run([mean, variance])
+      # Make sure that there are no NaNs
+      self.assertFalse(np.isnan(mean).any())
+      self.assertFalse(np.isnan(variance).any())
+      self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
+      # The variance is unstable
+      self.assertGreater(np.abs(variance - expected_var), 0.1)
+
 if __name__ == "__main__":
   test_lib.main()
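
Note on the instability that `testUnstableOutputShiftNone` pins down: `normalize_moments` derives the variance from shifted sufficient statistics as E[(x - shift)^2] - (E[x - shift])^2. With `shift=0.0` and data distributed as in the test (mean around 1e3, spread 0.1), both terms are around 1e6 while their true difference is about 8e-4, far below float32 resolution at that magnitude, so the subtraction cancels catastrophically. A minimal NumPy sketch of the effect, outside the patch (the helper `shifted_var` is invented here for illustration):

```python
import numpy as np

# Same distribution as testUnstableOutputShiftNone: uniform spread 0.1 around 1e3.
x = (np.random.rand(10, 300) * 0.1 + 1e3).astype(np.float32)

def shifted_var(x, shift):
  # Variance from shifted sufficient statistics, mirroring normalize_moments:
  # E[(x - s)^2] - (E[x - s])^2.
  d = x - np.float32(shift)
  return np.mean(d * d) - np.mean(d) ** 2

print(shifted_var(x, 0.0))          # shift=0: two ~1e6 terms cancel; result is noise
print(shifted_var(x, x.mean()))     # shift=true mean: terms are ~1e-3, stable
print(np.var(x, dtype=np.float64))  # float64 reference, about 0.1**2 / 12 ~ 8.3e-4
```

Wrapping the computed mean in `stop_gradient` treats the shift as a constant in the backward pass; since the moments are mathematically invariant to the shift, this avoids routing a second, redundant gradient path through the `reduce_mean`.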