Fix moments calculation to never result in negative variance and avoid
doing extra work when shift = None. With the current calculation, shift is ignored.
PiperOrigin-RevId: 161003939
This commit is contained in:
parent
70804d820b
commit
eccd162119
@ -580,15 +580,16 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
|||||||
return (mean, variance)
|
return (mean, variance)
|
||||||
|
|
||||||
|
|
||||||
def moments(x, axes, shift=None, name=None, keep_dims=False):
|
def moments(x, axes,
|
||||||
|
shift=None, # pylint: disable=unused-argument
|
||||||
|
name=None, keep_dims=False):
|
||||||
"""Calculate the mean and variance of `x`.
|
"""Calculate the mean and variance of `x`.
|
||||||
|
|
||||||
The mean and variance are calculated by aggregating the contents of `x`
|
The mean and variance are calculated by aggregating the contents of `x`
|
||||||
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
|
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
|
||||||
and variance of a vector.
|
and variance of a vector.
|
||||||
|
|
||||||
Note: for numerical stability, when shift=None, the true mean
|
Note: shift is currently not used, the true mean is computed and used.
|
||||||
would be computed and used as shift.
|
|
||||||
|
|
||||||
When using these moments for batch normalization (see
|
When using these moments for batch normalization (see
|
||||||
`tf.nn.batch_normalization`):
|
`tf.nn.batch_normalization`):
|
||||||
@ -601,35 +602,26 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
|
|||||||
x: A `Tensor`.
|
x: A `Tensor`.
|
||||||
axes: Array of ints. Axes along which to compute mean and
|
axes: Array of ints. Axes along which to compute mean and
|
||||||
variance.
|
variance.
|
||||||
shift: A `Tensor` containing the value by which to shift the data for
|
shift: Not used in the current implementation
|
||||||
numerical stability, or `None` in which case the true mean of the data is
|
|
||||||
used as shift. A shift close to the true mean provides the most
|
|
||||||
numerically stable results.
|
|
||||||
name: Name used to scope the operations that compute the moments.
|
name: Name used to scope the operations that compute the moments.
|
||||||
keep_dims: produce moments with the same dimensionality as the input.
|
keep_dims: produce moments with the same dimensionality as the input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Two `Tensor` objects: `mean` and `variance`.
|
Two `Tensor` objects: `mean` and `variance`.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "moments", [x, axes, shift]):
|
with ops.name_scope(name, "moments", [x, axes]):
|
||||||
# The dynamic range of fp16 is too limited to support the collection of
|
# The dynamic range of fp16 is too limited to support the collection of
|
||||||
# sufficient statistics. As a workaround we simply perform the operations
|
# sufficient statistics. As a workaround we simply perform the operations
|
||||||
# on 32-bit floats before converting the mean and variance back to fp16
|
# on 32-bit floats before converting the mean and variance back to fp16
|
||||||
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
||||||
if shift is None:
|
# Compute true mean while keeping the dims for proper broadcasting.
|
||||||
# Compute true mean while keeping the dims for proper broadcasting.
|
mean = math_ops.reduce_mean(y, axes, keep_dims=True, name="mean")
|
||||||
shift = array_ops.stop_gradient(
|
# sample variance, not unbiased variance
|
||||||
math_ops.reduce_mean(y, axes, keep_dims=True))
|
variance = math_ops.reduce_mean(
|
||||||
else:
|
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
|
||||||
shift = math_ops.cast(shift, y.dtype)
|
axes,
|
||||||
shifted_mean = math_ops.reduce_mean(
|
keep_dims=True,
|
||||||
math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean")
|
|
||||||
variance = math_ops.subtract(
|
|
||||||
math_ops.reduce_mean(
|
|
||||||
math_ops.squared_difference(y, shift), axes, keep_dims=True),
|
|
||||||
math_ops.square(shifted_mean),
|
|
||||||
name="variance")
|
name="variance")
|
||||||
mean = math_ops.add(shifted_mean, shift, name="mean")
|
|
||||||
if not keep_dims:
|
if not keep_dims:
|
||||||
mean = array_ops.squeeze(mean, axes)
|
mean = array_ops.squeeze(mean, axes)
|
||||||
variance = array_ops.squeeze(variance, axes)
|
variance = array_ops.squeeze(variance, axes)
|
||||||
|
@ -877,28 +877,6 @@ class MomentsTest(test_lib.TestCase):
|
|||||||
def testOutput4DInput123(self):
|
def testOutput4DInput123(self):
|
||||||
self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
|
self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
|
||||||
|
|
||||||
def testUnstableOutputShiftNone(self):
|
|
||||||
input_shape = (10, 300)
|
|
||||||
moments_axes = (0, 1)
|
|
||||||
mu, sigma = 1e3, 0.1
|
|
||||||
tol = 1e-3
|
|
||||||
input_values = np.random.rand(*input_shape) * sigma + mu
|
|
||||||
expected_mean = np.mean(input_values, axis=moments_axes)
|
|
||||||
expected_var = np.var(input_values, axis=moments_axes)
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
|
||||||
inputs = constant_op.constant(
|
|
||||||
input_values, shape=input_shape, dtype=dtypes.float32)
|
|
||||||
mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)
|
|
||||||
|
|
||||||
[mean, variance] = sess.run([mean, variance])
|
|
||||||
# Make sure that there are no NaNs
|
|
||||||
self.assertFalse(np.isnan(mean).any())
|
|
||||||
self.assertFalse(np.isnan(variance).any())
|
|
||||||
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
|
|
||||||
# The variance is unstable
|
|
||||||
self.assertGreater(np.abs(variance - expected_var), 0.1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_lib.main()
|
test_lib.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user