Allow callers to provide a shift value in tf.nn.moments().
Change: 123095477
This commit is contained in:
parent
dcfeb027ac
commit
4ea0733c46
@ -10,6 +10,9 @@
|
|||||||
## Bug Fixes and Other Changes
|
## Bug Fixes and Other Changes
|
||||||
* TensorBoard now displays graphs with only one data point
|
* TensorBoard now displays graphs with only one data point
|
||||||
* TensorBoard now visually displays NaN values
|
* TensorBoard now visually displays NaN values
|
||||||
|
* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate
|
||||||
|
of the mean improves numerical stability. Also changes the behavior of the
|
||||||
|
`shift` argument to `tf.nn.sufficient_statistics()`.
|
||||||
|
|
||||||
# Release 0.8.0
|
# Release 0.8.0
|
||||||
|
|
||||||
|
@ -587,7 +587,7 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
|
|||||||
padding="VALID", name=name)
|
padding="VALID", name=name)
|
||||||
|
|
||||||
|
|
||||||
def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
|
||||||
"""Calculate the sufficient statistics for the mean and variance of `x`.
|
"""Calculate the sufficient statistics for the mean and variance of `x`.
|
||||||
|
|
||||||
These sufficient statistics are computed using the one pass algorithm on
|
These sufficient statistics are computed using the one pass algorithm on
|
||||||
@ -601,7 +601,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
|||||||
Args:
|
Args:
|
||||||
x: A `Tensor`.
|
x: A `Tensor`.
|
||||||
axes: Array of ints. Axes along which to compute mean and variance.
|
axes: Array of ints. Axes along which to compute mean and variance.
|
||||||
shift: If true, shift the data to provide more numerically stable results.
|
shift: A `Tensor` containing the value by which to shift the data for
|
||||||
|
numerical stability, or `None` if no shift is to be performed. A shift
|
||||||
|
close to the true mean provides the most numerically stable results.
|
||||||
keep_dims: produce statistics with the same dimensionality as the input.
|
keep_dims: produce statistics with the same dimensionality as the input.
|
||||||
name: Name used to scope the operations that compute the sufficient stats.
|
name: Name used to scope the operations that compute the sufficient stats.
|
||||||
|
|
||||||
@ -612,7 +614,7 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
|||||||
* the (possibly shifted) sum of squares of the elements in the array.
|
* the (possibly shifted) sum of squares of the elements in the array.
|
||||||
* the shift by which the mean must be corrected or None if `shift` is False.
|
* the shift by which the mean must be corrected or None if `shift` is False.
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([x, axes], name, "sufficient_statistics"):
|
with ops.op_scope([x, axes, shift], name, "sufficient_statistics"):
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
x_shape = x.get_shape()
|
x_shape = x.get_shape()
|
||||||
if x_shape.is_fully_defined():
|
if x_shape.is_fully_defined():
|
||||||
@ -635,23 +637,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
|
|||||||
math_ops.reduce_prod(x_shape / m_shape),
|
math_ops.reduce_prod(x_shape / m_shape),
|
||||||
x.dtype,
|
x.dtype,
|
||||||
name="count")
|
name="count")
|
||||||
if shift:
|
if shift is not None:
|
||||||
shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
|
shift = ops.convert_to_tensor(shift, name="shift")
|
||||||
m_ss = math_ops.sub(x, shift_value)
|
m_ss = math_ops.sub(x, shift)
|
||||||
v_ss = math_ops.squared_difference(x, shift_value)
|
v_ss = math_ops.squared_difference(x, shift)
|
||||||
if keep_dims:
|
else: # no shift.
|
||||||
shift_value = array_ops.identity(shift_value, name="shift")
|
|
||||||
else:
|
|
||||||
shift_value = array_ops.squeeze(shift_value,
|
|
||||||
squeeze_dims=axes,
|
|
||||||
name="shift")
|
|
||||||
else: # not shift.
|
|
||||||
m_ss = x
|
m_ss = x
|
||||||
v_ss = math_ops.square(x)
|
v_ss = math_ops.square(x)
|
||||||
shift_value = None
|
|
||||||
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
|
m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
|
||||||
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
|
v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
|
||||||
return counts, m_ss, v_ss, shift_value
|
return counts, m_ss, v_ss, shift
|
||||||
|
|
||||||
|
|
||||||
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||||
@ -685,7 +680,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
|||||||
return (mean, variance)
|
return (mean, variance)
|
||||||
|
|
||||||
|
|
||||||
def moments(x, axes, name=None, keep_dims=False):
|
def moments(x, axes, shift=None, name=None, keep_dims=False):
|
||||||
"""Calculate the mean and variance of `x`.
|
"""Calculate the mean and variance of `x`.
|
||||||
|
|
||||||
The mean and variance are calculated by aggregating the contents of `x`
|
The mean and variance are calculated by aggregating the contents of `x`
|
||||||
@ -702,15 +697,19 @@ def moments(x, axes, name=None, keep_dims=False):
|
|||||||
x: A `Tensor`.
|
x: A `Tensor`.
|
||||||
axes: array of ints. Axes along which to compute mean and
|
axes: array of ints. Axes along which to compute mean and
|
||||||
variance.
|
variance.
|
||||||
|
shift: A `Tensor` containing the value by which to shift the data for
|
||||||
|
numerical stability, or `None` if no shift is to be performed. A shift
|
||||||
|
close to the true mean provides the most numerically stable results.
|
||||||
keep_dims: produce moments with the same dimensionality as the input.
|
keep_dims: produce moments with the same dimensionality as the input.
|
||||||
name: Name used to scope the operations that compute the moments.
|
name: Name used to scope the operations that compute the moments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Two `Tensor` objects: `mean` and `variance`.
|
Two `Tensor` objects: `mean` and `variance`.
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([x, axes], name, "moments"):
|
with ops.op_scope([x, axes, shift], name, "moments"):
|
||||||
counts, m_ss, v_ss, shift = sufficient_statistics(x,
|
counts, m_ss, v_ss, shift = sufficient_statistics(x,
|
||||||
axes,
|
axes,
|
||||||
|
shift=shift,
|
||||||
keep_dims=keep_dims,
|
keep_dims=keep_dims,
|
||||||
name=name)
|
name=name)
|
||||||
return normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
return normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
||||||
|
@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
|||||||
|
|
||||||
def _npSuffStats(self, x, axes, shift, keep_dims):
|
def _npSuffStats(self, x, axes, shift, keep_dims):
|
||||||
axis = tuple(axes)
|
axis = tuple(axes)
|
||||||
if shift:
|
if shift is not None:
|
||||||
shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
|
m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
|
||||||
for i in xrange(x.ndim)]]
|
v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
|
||||||
m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
|
|
||||||
v_ss = np.sum(
|
|
||||||
(x - shift_value) * (x - shift_value),
|
|
||||||
axis=axis,
|
|
||||||
keepdims=keep_dims)
|
|
||||||
else:
|
else:
|
||||||
shift_value = None
|
|
||||||
m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
|
m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
|
||||||
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
|
v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
|
||||||
count = 1.0
|
count = 1.0
|
||||||
@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
|||||||
if d in set(axes):
|
if d in set(axes):
|
||||||
count *= x.shape[d]
|
count *= x.shape[d]
|
||||||
if not keep_dims:
|
if not keep_dims:
|
||||||
shift_value = np.squeeze(shift_value, axis=axis)
|
shift = np.squeeze(shift, axis=axis)
|
||||||
return count, m_ss, v_ss, shift_value
|
return count, m_ss, v_ss, shift
|
||||||
|
|
||||||
def _opSuffStats(self, x, axes, shift, keep_dims):
|
def _opSuffStats(self, x, axes, shift, keep_dims):
|
||||||
return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
|
return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
|
||||||
@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
|||||||
def testSuffStats(self):
|
def testSuffStats(self):
|
||||||
for has_shape in [True, False]:
|
for has_shape in [True, False]:
|
||||||
for keep_dims in [True, False]:
|
for keep_dims in [True, False]:
|
||||||
for shift in [True, False]:
|
for shift in [None, 1.0]:
|
||||||
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
|
self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
|
||||||
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
|
self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
|
||||||
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
|
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
|
||||||
@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase):
|
|||||||
self.assertAllClose(npv, tfv, atol=0.000001)
|
self.assertAllClose(npv, tfv, atol=0.000001)
|
||||||
|
|
||||||
def testNormalizeMoments(self):
|
def testNormalizeMoments(self):
|
||||||
for shift in [True, False]:
|
for shift in [None, 4.0]:
|
||||||
self._testNormalizeMoments([3], shift)
|
self._testNormalizeMoments([3], shift)
|
||||||
self._testNormalizeMoments([2, 3], shift)
|
self._testNormalizeMoments([2, 3], shift)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user