Allow callers to provide a shift value in tf.nn.moments().

Change: 123095477
This commit is contained in:
A. Unique TensorFlower 2016-05-24 04:19:11 -08:00 committed by TensorFlower Gardener
parent dcfeb027ac
commit 4ea0733c46
3 changed files with 27 additions and 31 deletions

View File

@ -10,6 +10,9 @@
## Bug Fixes and Other Changes
* TensorBoard now displays graphs with only one data point
* TensorBoard now visually displays NaN values
* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate
of the mean improves numerical stability. Also changes the behavior of the
`shift` argument to `tf.nn.sufficient_statistics()`.
# Release 0.8.0

View File

@ -587,7 +587,7 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
padding="VALID", name=name)
def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
"""Calculate the sufficient statistics for the mean and variance of `x`.
These sufficient statistics are computed using the one pass algorithm on
@ -601,7 +601,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
Args:
x: A `Tensor`.
axes: Array of ints. Axes along which to compute mean and variance.
shift: If true, shift the data to provide more numerically stable results.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
close to the true mean provides the most numerically stable results.
keep_dims: produce statistics with the same dimensionality as the input.
name: Name used to scope the operations that compute the sufficient stats.
@ -612,7 +614,7 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
* the (possibly shifted) sum of squares of the elements in the array.
* the shift by which the mean must be corrected or None if `shift` is False.
"""
with ops.op_scope([x, axes], name, "sufficient_statistics"):
with ops.op_scope([x, axes, shift], name, "sufficient_statistics"):
x = ops.convert_to_tensor(x, name="x")
x_shape = x.get_shape()
if x_shape.is_fully_defined():
@ -635,23 +637,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
math_ops.reduce_prod(x_shape / m_shape),
x.dtype,
name="count")
if shift:
shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
m_ss = math_ops.sub(x, shift_value)
v_ss = math_ops.squared_difference(x, shift_value)
if keep_dims:
shift_value = array_ops.identity(shift_value, name="shift")
else:
shift_value = array_ops.squeeze(shift_value,
squeeze_dims=axes,
name="shift")
else: # not shift.
if shift is not None:
shift = ops.convert_to_tensor(shift, name="shift")
m_ss = math_ops.sub(x, shift)
v_ss = math_ops.squared_difference(x, shift)
else: # no shift.
m_ss = x
v_ss = math_ops.square(x)
shift_value = None
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
return counts, m_ss, v_ss, shift_value
return counts, m_ss, v_ss, shift
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@ -685,7 +680,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
return (mean, variance)
def moments(x, axes, name=None, keep_dims=False):
def moments(x, axes, shift=None, name=None, keep_dims=False):
"""Calculate the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `x`
@ -702,15 +697,19 @@ def moments(x, axes, name=None, keep_dims=False):
x: A `Tensor`.
axes: array of ints. Axes along which to compute mean and
variance.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
close to the true mean provides the most numerically stable results.
keep_dims: produce moments with the same dimensionality as the input.
name: Name used to scope the operations that compute the moments.
Returns:
Two `Tensor` objects: `mean` and `variance`.
"""
with ops.op_scope([x, axes], name, "moments"):
with ops.op_scope([x, axes, shift], name, "moments"):
counts, m_ss, v_ss, shift = sufficient_statistics(x,
axes,
shift=shift,
keep_dims=keep_dims,
name=name)
return normalize_moments(counts, m_ss, v_ss, shift, name=name)

View File

@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase):
def _npSuffStats(self, x, axes, shift, keep_dims):
axis = tuple(axes)
if shift:
shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
for i in xrange(x.ndim)]]
m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
v_ss = np.sum(
(x - shift_value) * (x - shift_value),
axis=axis,
keepdims=keep_dims)
if shift is not None:
m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
else:
shift_value = None
m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
count = 1.0
@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase):
if d in set(axes):
count *= x.shape[d]
if not keep_dims:
shift_value = np.squeeze(shift_value, axis=axis)
return count, m_ss, v_ss, shift_value
shift = np.squeeze(shift, axis=axis)
return count, m_ss, v_ss, shift
def _opSuffStats(self, x, axes, shift, keep_dims):
return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase):
def testSuffStats(self):
for has_shape in [True, False]:
for keep_dims in [True, False]:
for shift in [True, False]:
for shift in [None, 1.0]:
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase):
self.assertAllClose(npv, tfv, atol=0.000001)
def testNormalizeMoments(self):
for shift in [True, False]:
for shift in [None, 4.0]:
self._testNormalizeMoments([3], shift)
self._testNormalizeMoments([2, 3], shift)