Allow callers to provide a shift value in tf.nn.moments().

Change: 123095477
This commit is contained in:
A. Unique TensorFlower 2016-05-24 04:19:11 -08:00 committed by TensorFlower Gardener
parent dcfeb027ac
commit 4ea0733c46
3 changed files with 27 additions and 31 deletions

View File

@ -10,6 +10,9 @@
## Bug Fixes and Other Changes ## Bug Fixes and Other Changes
* TensorBoard now displays graphs with only one data point * TensorBoard now displays graphs with only one data point
* TensorBoard now visually displays NaN values * TensorBoard now visually displays NaN values
* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate
of the mean improves numerical stability. Also changes the behavior of the
`shift` argument to `tf.nn.sufficient_statistics()`.
# Release 0.8.0 # Release 0.8.0

View File

@ -587,7 +587,7 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
padding="VALID", name=name) padding="VALID", name=name)
def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None): def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
"""Calculate the sufficient statistics for the mean and variance of `x`. """Calculate the sufficient statistics for the mean and variance of `x`.
These sufficient statistics are computed using the one pass algorithm on These sufficient statistics are computed using the one pass algorithm on
@ -601,7 +601,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
Args: Args:
x: A `Tensor`. x: A `Tensor`.
axes: Array of ints. Axes along which to compute mean and variance. axes: Array of ints. Axes along which to compute mean and variance.
shift: If true, shift the data to provide more numerically stable results. shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
close to the true mean provides the most numerically stable results.
keep_dims: produce statistics with the same dimensionality as the input. keep_dims: produce statistics with the same dimensionality as the input.
name: Name used to scope the operations that compute the sufficient stats. name: Name used to scope the operations that compute the sufficient stats.
@ -612,7 +614,7 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
* the (possibly shifted) sum of squares of the elements in the array. * the (possibly shifted) sum of squares of the elements in the array.
* the shift by which the mean must be corrected or None if `shift` is False. * the shift by which the mean must be corrected or None if `shift` is False.
""" """
with ops.op_scope([x, axes], name, "sufficient_statistics"): with ops.op_scope([x, axes, shift], name, "sufficient_statistics"):
x = ops.convert_to_tensor(x, name="x") x = ops.convert_to_tensor(x, name="x")
x_shape = x.get_shape() x_shape = x.get_shape()
if x_shape.is_fully_defined(): if x_shape.is_fully_defined():
@ -635,23 +637,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
math_ops.reduce_prod(x_shape / m_shape), math_ops.reduce_prod(x_shape / m_shape),
x.dtype, x.dtype,
name="count") name="count")
if shift: if shift is not None:
shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape) shift = ops.convert_to_tensor(shift, name="shift")
m_ss = math_ops.sub(x, shift_value) m_ss = math_ops.sub(x, shift)
v_ss = math_ops.squared_difference(x, shift_value) v_ss = math_ops.squared_difference(x, shift)
if keep_dims: else: # no shift.
shift_value = array_ops.identity(shift_value, name="shift")
else:
shift_value = array_ops.squeeze(shift_value,
squeeze_dims=axes,
name="shift")
else: # not shift.
m_ss = x m_ss = x
v_ss = math_ops.square(x) v_ss = math_ops.square(x)
shift_value = None
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss") m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss") v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
return counts, m_ss, v_ss, shift_value return counts, m_ss, v_ss, shift
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@ -685,7 +680,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
return (mean, variance) return (mean, variance)
def moments(x, axes, name=None, keep_dims=False): def moments(x, axes, shift=None, name=None, keep_dims=False):
"""Calculate the mean and variance of `x`. """Calculate the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `x` The mean and variance are calculated by aggregating the contents of `x`
@ -702,15 +697,19 @@ def moments(x, axes, name=None, keep_dims=False):
x: A `Tensor`. x: A `Tensor`.
axes: array of ints. Axes along which to compute mean and axes: array of ints. Axes along which to compute mean and
variance. variance.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
close to the true mean provides the most numerically stable results.
keep_dims: produce moments with the same dimensionality as the input. keep_dims: produce moments with the same dimensionality as the input.
name: Name used to scope the operations that compute the moments. name: Name used to scope the operations that compute the moments.
Returns: Returns:
Two `Tensor` objects: `mean` and `variance`. Two `Tensor` objects: `mean` and `variance`.
""" """
with ops.op_scope([x, axes], name, "moments"): with ops.op_scope([x, axes, shift], name, "moments"):
counts, m_ss, v_ss, shift = sufficient_statistics(x, counts, m_ss, v_ss, shift = sufficient_statistics(x,
axes, axes,
shift=shift,
keep_dims=keep_dims, keep_dims=keep_dims,
name=name) name=name)
return normalize_moments(counts, m_ss, v_ss, shift, name=name) return normalize_moments(counts, m_ss, v_ss, shift, name=name)

View File

@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase):
def _npSuffStats(self, x, axes, shift, keep_dims): def _npSuffStats(self, x, axes, shift, keep_dims):
axis = tuple(axes) axis = tuple(axes)
if shift: if shift is not None:
shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1) m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
for i in xrange(x.ndim)]] v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
v_ss = np.sum(
(x - shift_value) * (x - shift_value),
axis=axis,
keepdims=keep_dims)
else: else:
shift_value = None
m_ss = np.sum(x, axis=axis, keepdims=keep_dims) m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims) v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
count = 1.0 count = 1.0
@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase):
if d in set(axes): if d in set(axes):
count *= x.shape[d] count *= x.shape[d]
if not keep_dims: if not keep_dims:
shift_value = np.squeeze(shift_value, axis=axis) shift = np.squeeze(shift, axis=axis)
return count, m_ss, v_ss, shift_value return count, m_ss, v_ss, shift
def _opSuffStats(self, x, axes, shift, keep_dims): def _opSuffStats(self, x, axes, shift, keep_dims):
return tf.nn.sufficient_statistics(x, axes, shift, keep_dims) return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase):
def testSuffStats(self): def testSuffStats(self):
for has_shape in [True, False]: for has_shape in [True, False]:
for keep_dims in [True, False]: for keep_dims in [True, False]:
for shift in [True, False]: for shift in [None, 1.0]:
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape) self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape) self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape) self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase):
self.assertAllClose(npv, tfv, atol=0.000001) self.assertAllClose(npv, tfv, atol=0.000001)
def testNormalizeMoments(self): def testNormalizeMoments(self):
for shift in [True, False]: for shift in [None, 4.0]:
self._testNormalizeMoments([3], shift) self._testNormalizeMoments([3], shift)
self._testNormalizeMoments([2, 3], shift) self._testNormalizeMoments([2, 3], shift)