Allow callers to provide a shift value in tf.nn.moments().
Change: 123095477
This commit is contained in:
parent
dcfeb027ac
commit
4ea0733c46
@ -10,6 +10,9 @@
|
||||
## Bug Fixes and Other Changes
|
||||
* TensorBoard now displays graphs with only one data point
|
||||
* TensorBoard now visually displays NaN values
|
||||
* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate
|
||||
of the mean improves numerical stability. Also changes the behavior of the
|
||||
`shift` argument to `tf.nn.sufficient_statistics()`.
|
||||
|
||||
# Release 0.8.0
|
||||
|
||||
|
@ -587,7 +587,7 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
|
||||
padding="VALID", name=name)
|
||||
|
||||
|
||||
def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
||||
def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
|
||||
"""Calculate the sufficient statistics for the mean and variance of `x`.
|
||||
|
||||
These sufficient statistics are computed using the one pass algorithm on
|
||||
@ -601,7 +601,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
||||
Args:
|
||||
x: A `Tensor`.
|
||||
axes: Array of ints. Axes along which to compute mean and variance.
|
||||
shift: If true, shift the data to provide more numerically stable results.
|
||||
shift: A `Tensor` containing the value by which to shift the data for
|
||||
numerical stability, or `None` if no shift is to be performed. A shift
|
||||
close to the true mean provides the most numerically stable results.
|
||||
keep_dims: produce statistics with the same dimensionality as the input.
|
||||
name: Name used to scope the operations that compute the sufficient stats.
|
||||
|
||||
@ -612,7 +614,7 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
||||
* the (possibly shifted) sum of squares of the elements in the array.
|
||||
* the shift by which the mean must be corrected or None if `shift` is False.
|
||||
"""
|
||||
with ops.op_scope([x, axes], name, "sufficient_statistics"):
|
||||
with ops.op_scope([x, axes, shift], name, "sufficient_statistics"):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x_shape = x.get_shape()
|
||||
if x_shape.is_fully_defined():
|
||||
@ -635,23 +637,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
||||
math_ops.reduce_prod(x_shape / m_shape),
|
||||
x.dtype,
|
||||
name="count")
|
||||
if shift:
|
||||
shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
|
||||
m_ss = math_ops.sub(x, shift_value)
|
||||
v_ss = math_ops.squared_difference(x, shift_value)
|
||||
if keep_dims:
|
||||
shift_value = array_ops.identity(shift_value, name="shift")
|
||||
else:
|
||||
shift_value = array_ops.squeeze(shift_value,
|
||||
squeeze_dims=axes,
|
||||
name="shift")
|
||||
else: # not shift.
|
||||
if shift is not None:
|
||||
shift = ops.convert_to_tensor(shift, name="shift")
|
||||
m_ss = math_ops.sub(x, shift)
|
||||
v_ss = math_ops.squared_difference(x, shift)
|
||||
else: # no shift.
|
||||
m_ss = x
|
||||
v_ss = math_ops.square(x)
|
||||
shift_value = None
|
||||
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
|
||||
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
|
||||
return counts, m_ss, v_ss, shift_value
|
||||
return counts, m_ss, v_ss, shift
|
||||
|
||||
|
||||
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||
@ -685,7 +680,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||
return (mean, variance)
|
||||
|
||||
|
||||
def moments(x, axes, name=None, keep_dims=False):
|
||||
def moments(x, axes, shift=None, name=None, keep_dims=False):
|
||||
"""Calculate the mean and variance of `x`.
|
||||
|
||||
The mean and variance are calculated by aggregating the contents of `x`
|
||||
@ -702,15 +697,19 @@ def moments(x, axes, name=None, keep_dims=False):
|
||||
x: A `Tensor`.
|
||||
axes: array of ints. Axes along which to compute mean and
|
||||
variance.
|
||||
shift: A `Tensor` containing the value by which to shift the data for
|
||||
numerical stability, or `None` if no shift is to be performed. A shift
|
||||
close to the true mean provides the most numerically stable results.
|
||||
keep_dims: produce moments with the same dimensionality as the input.
|
||||
name: Name used to scope the operations that compute the moments.
|
||||
|
||||
Returns:
|
||||
Two `Tensor` objects: `mean` and `variance`.
|
||||
"""
|
||||
with ops.op_scope([x, axes], name, "moments"):
|
||||
with ops.op_scope([x, axes, shift], name, "moments"):
|
||||
counts, m_ss, v_ss, shift = sufficient_statistics(x,
|
||||
axes,
|
||||
shift=shift,
|
||||
keep_dims=keep_dims,
|
||||
name=name)
|
||||
return normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
||||
|
@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
||||
|
||||
def _npSuffStats(self, x, axes, shift, keep_dims):
|
||||
axis = tuple(axes)
|
||||
if shift:
|
||||
shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
|
||||
for i in xrange(x.ndim)]]
|
||||
m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
|
||||
v_ss = np.sum(
|
||||
(x - shift_value) * (x - shift_value),
|
||||
axis=axis,
|
||||
keepdims=keep_dims)
|
||||
if shift is not None:
|
||||
m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
|
||||
v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
|
||||
else:
|
||||
shift_value = None
|
||||
m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
|
||||
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
|
||||
count = 1.0
|
||||
@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
||||
if d in set(axes):
|
||||
count *= x.shape[d]
|
||||
if not keep_dims:
|
||||
shift_value = np.squeeze(shift_value, axis=axis)
|
||||
return count, m_ss, v_ss, shift_value
|
||||
shift = np.squeeze(shift, axis=axis)
|
||||
return count, m_ss, v_ss, shift
|
||||
|
||||
def _opSuffStats(self, x, axes, shift, keep_dims):
|
||||
return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
|
||||
@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
||||
def testSuffStats(self):
|
||||
for has_shape in [True, False]:
|
||||
for keep_dims in [True, False]:
|
||||
for shift in [True, False]:
|
||||
for shift in [None, 1.0]:
|
||||
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
|
||||
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
|
||||
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
|
||||
@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase):
|
||||
self.assertAllClose(npv, tfv, atol=0.000001)
|
||||
|
||||
def testNormalizeMoments(self):
|
||||
for shift in [True, False]:
|
||||
for shift in [None, 4.0]:
|
||||
self._testNormalizeMoments([3], shift)
|
||||
self._testNormalizeMoments([2, 3], shift)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user