Make moments numerically stable by default. Added tests for moments.
Change: 144114955
This commit is contained in:
parent
b7fc7f3f46
commit
fdbd02c8d7
@ -580,6 +580,9 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
|
|||||||
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
|
across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
|
||||||
and variance of a vector.
|
and variance of a vector.
|
||||||
|
|
||||||
|
Note: for numerical stability, when shift=None, the true mean
|
||||||
|
would be computed and used as shift.
|
||||||
|
|
||||||
When using these moments for batch normalization (see
|
When using these moments for batch normalization (see
|
||||||
`tf.nn.batch_normalization`):
|
`tf.nn.batch_normalization`):
|
||||||
|
|
||||||
@ -592,8 +595,9 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
|
|||||||
axes: Array of ints. Axes along which to compute mean and
|
axes: Array of ints. Axes along which to compute mean and
|
||||||
variance.
|
variance.
|
||||||
shift: A `Tensor` containing the value by which to shift the data for
|
shift: A `Tensor` containing the value by which to shift the data for
|
||||||
numerical stability, or `None` if no shift is to be performed. A shift
|
numerical stability, or `None` in which case the true mean of the data is
|
||||||
close to the true mean provides the most numerically stable results.
|
used as shift. A shift close to the true mean provides the most
|
||||||
|
numerically stable results.
|
||||||
name: Name used to scope the operations that compute the moments.
|
name: Name used to scope the operations that compute the moments.
|
||||||
keep_dims: produce moments with the same dimensionality as the input.
|
keep_dims: produce moments with the same dimensionality as the input.
|
||||||
|
|
||||||
@ -605,10 +609,17 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
|
|||||||
# sufficient statistics. As a workaround we simply perform the operations
|
# sufficient statistics. As a workaround we simply perform the operations
|
||||||
# on 32-bit floats before converting the mean and variance back to fp16
|
# on 32-bit floats before converting the mean and variance back to fp16
|
||||||
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
||||||
shift = math_ops.cast(shift, dtypes.float32) if (
|
if shift is None:
|
||||||
shift is not None and x.dtype == dtypes.float16) else shift
|
# Compute true mean while keeping the dims for proper broadcasting.
|
||||||
|
shift = array_ops.stop_gradient(
|
||||||
|
math_ops.reduce_mean(y, axes, keep_dims=True))
|
||||||
|
else:
|
||||||
|
shift = math_ops.cast(shift, y.dtype)
|
||||||
counts, m_ss, v_ss, shift = sufficient_statistics(
|
counts, m_ss, v_ss, shift = sufficient_statistics(
|
||||||
y, axes, shift=shift, keep_dims=keep_dims, name=name)
|
y, axes, shift=shift, keep_dims=keep_dims, name=name)
|
||||||
|
# Reshape shift as needed.
|
||||||
|
shift = array_ops.reshape(shift, array_ops.shape(m_ss))
|
||||||
|
shift.set_shape(m_ss.get_shape())
|
||||||
with ops.control_dependencies([counts, m_ss, v_ss]):
|
with ops.control_dependencies([counts, m_ss, v_ss]):
|
||||||
mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
||||||
if x.dtype == dtypes.float16:
|
if x.dtype == dtypes.float16:
|
||||||
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gradient_checker
|
from tensorflow.python.ops import gradient_checker
|
||||||
from tensorflow.python.ops import nn_impl
|
from tensorflow.python.ops import nn_impl
|
||||||
@ -791,5 +792,78 @@ class CReluTest(test_lib.TestCase):
|
|||||||
self.assertAllClose(y, z, 1e-4)
|
self.assertAllClose(y, z, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class MomentsTest(test_lib.TestCase):
|
||||||
|
|
||||||
|
def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
|
||||||
|
for mu in [0.0, 1.0, 1e3]:
|
||||||
|
for sigma in [1.0, 0.1]:
|
||||||
|
for keep_dims in [True, False]:
|
||||||
|
input_values = np.random.rand(*input_shape) * sigma + mu
|
||||||
|
expected_mean = np.mean(input_values, axis=moments_axes,
|
||||||
|
keepdims=keep_dims)
|
||||||
|
expected_var = np.var(input_values, axis=moments_axes,
|
||||||
|
keepdims=keep_dims)
|
||||||
|
with ops.Graph().as_default() as g:
|
||||||
|
with self.test_session(graph=g) as sess:
|
||||||
|
inputs = constant_op.constant(input_values,
|
||||||
|
shape=input_shape,
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
mean, variance = nn_impl.moments(inputs,
|
||||||
|
moments_axes,
|
||||||
|
keep_dims=keep_dims)
|
||||||
|
|
||||||
|
[mean, variance] = sess.run([mean, variance])
|
||||||
|
# Make sure that there are no NaNs
|
||||||
|
self.assertFalse(np.isnan(mean).any())
|
||||||
|
self.assertFalse(np.isnan(variance).any())
|
||||||
|
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
|
||||||
|
self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
def testOutput2DInput0(self):
|
||||||
|
self.doOutputTest((10, 300), (0,))
|
||||||
|
|
||||||
|
def testOutput2DInput1(self):
|
||||||
|
self.doOutputTest((10, 300), (1,))
|
||||||
|
|
||||||
|
def testOutput2DInput01(self):
|
||||||
|
self.doOutputTest((10, 300), (0, 1))
|
||||||
|
|
||||||
|
def testOutput4DInput0(self):
|
||||||
|
self.doOutputTest((10, 10, 10, 30), (0,))
|
||||||
|
|
||||||
|
def testOutput4DInput1(self):
|
||||||
|
self.doOutputTest((10, 10, 10, 30), (1,))
|
||||||
|
|
||||||
|
def testOutput4DInput3(self):
|
||||||
|
self.doOutputTest((10, 10, 10, 30), (3,))
|
||||||
|
|
||||||
|
def testOutput4DInput012(self):
|
||||||
|
self.doOutputTest((10, 10, 10, 30), (0, 1, 2))
|
||||||
|
|
||||||
|
def testOutput4DInput123(self):
|
||||||
|
self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
|
||||||
|
|
||||||
|
def testUnstableOutputShiftNone(self):
|
||||||
|
input_shape = (10, 300)
|
||||||
|
moments_axes = (0, 1)
|
||||||
|
mu, sigma = 1e3, 0.1
|
||||||
|
tol = 1e-3
|
||||||
|
input_values = np.random.rand(*input_shape) * sigma + mu
|
||||||
|
expected_mean = np.mean(input_values, axis=moments_axes)
|
||||||
|
expected_var = np.var(input_values, axis=moments_axes)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
inputs = constant_op.constant(input_values, shape=input_shape,
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)
|
||||||
|
|
||||||
|
[mean, variance] = sess.run([mean, variance])
|
||||||
|
# Make sure that there are no NaNs
|
||||||
|
self.assertFalse(np.isnan(mean).any())
|
||||||
|
self.assertFalse(np.isnan(variance).any())
|
||||||
|
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
|
||||||
|
# The variance is unstable
|
||||||
|
self.assertGreater(np.abs(variance - expected_var), 0.1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_lib.main()
|
test_lib.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user