From 4ea0733c46ef5185c32a7cd80dd7f39807a92f5a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 May 2016 04:19:11 -0800 Subject: [PATCH] Allow callers to provide a shift value in tf.nn.moments(). Change: 123095477 --- RELEASE.md | 3 ++ tensorflow/python/ops/nn.py | 35 +++++++++++----------- tensorflow/python/ops/nn_batchnorm_test.py | 20 +++++-------- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index c4d730e991e..0a7609c0bcc 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -10,6 +10,9 @@ ## Bug Fixes and Other Changes * TensorBoard now displays graphs with only one data point * TensorBoard now visually displays NaN values +* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate + of the mean improves numerical stability. Also changes the behavior of the + `shift` argument to `tf.nn.sufficient_statistics()`. # Release 0.8.0 diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 92e4ed8c4d4..06f7305a158 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -587,7 +587,7 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides, padding="VALID", name=name) -def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None): +def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None): """Calculate the sufficient statistics for the mean and variance of `x`. These sufficient statistics are computed using the one pass algorithm on @@ -601,7 +601,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None): Args: x: A `Tensor`. axes: Array of ints. Axes along which to compute mean and variance. - shift: If true, shift the data to provide more numerically stable results. + shift: A `Tensor` containing the value by which to shift the data for + numerical stability, or `None` if no shift is to be performed. A shift + close to the true mean provides the most numerically stable results. keep_dims: produce statistics with the same dimensionality as the input. name: Name used to scope the operations that compute the sufficient stats. @@ -612,7 +614,7 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None): * the (possibly shifted) sum of squares of the elements in the array. * the shift by which the mean must be corrected or None if `shift` is False. """ - with ops.op_scope([x, axes], name, "sufficient_statistics"): + with ops.op_scope([x, axes, shift], name, "sufficient_statistics"): x = ops.convert_to_tensor(x, name="x") x_shape = x.get_shape() if x_shape.is_fully_defined(): @@ -635,23 +637,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None): math_ops.reduce_prod(x_shape / m_shape), x.dtype, name="count") - if shift: - shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape) - m_ss = math_ops.sub(x, shift_value) - v_ss = math_ops.squared_difference(x, shift_value) - if keep_dims: - shift_value = array_ops.identity(shift_value, name="shift") - else: - shift_value = array_ops.squeeze(shift_value, - squeeze_dims=axes, - name="shift") - else: # not shift. + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + m_ss = math_ops.sub(x, shift) + v_ss = math_ops.squared_difference(x, shift) + else: # no shift. m_ss = x v_ss = math_ops.square(x) - shift_value = None m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss") v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss") - return counts, m_ss, v_ss, shift_value + return counts, m_ss, v_ss, shift def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): @@ -685,7 +680,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): return (mean, variance) -def moments(x, axes, name=None, keep_dims=False): +def moments(x, axes, shift=None, name=None, keep_dims=False): """Calculate the mean and variance of `x`. The mean and variance are calculated by aggregating the contents of `x` @@ -702,15 +697,19 @@ def moments(x, axes, name=None, keep_dims=False): x: A `Tensor`. axes: array of ints. Axes along which to compute mean and variance. + shift: A `Tensor` containing the value by which to shift the data for + numerical stability, or `None` if no shift is to be performed. A shift + close to the true mean provides the most numerically stable results. keep_dims: produce moments with the same dimensionality as the input. name: Name used to scope the operations that compute the moments. Returns: Two `Tensor` objects: `mean` and `variance`. """ - with ops.op_scope([x, axes], name, "moments"): + with ops.op_scope([x, axes, shift], name, "moments"): counts, m_ss, v_ss, shift = sufficient_statistics(x, axes, + shift=shift, keep_dims=keep_dims, name=name) return normalize_moments(counts, m_ss, v_ss, shift, name=name) diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py index 5b9f33a73bb..c6a27a803c4 100644 --- a/tensorflow/python/ops/nn_batchnorm_test.py +++ b/tensorflow/python/ops/nn_batchnorm_test.py @@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase): def _npSuffStats(self, x, axes, shift, keep_dims): axis = tuple(axes) - if shift: - shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1) - for i in xrange(x.ndim)]] - m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims) - v_ss = np.sum( - (x - shift_value) * (x - shift_value), - axis=axis, - keepdims=keep_dims) + if shift is not None: + m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims) + v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims) else: - shift_value = None m_ss = np.sum(x, axis=axis, keepdims=keep_dims) v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims) count = 1.0 @@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase): if d in set(axes): count *= x.shape[d] if not keep_dims: - shift_value = np.squeeze(shift_value, axis=axis) - return count, m_ss, v_ss, shift_value + shift = np.squeeze(shift, axis=axis) + return count, m_ss, v_ss, shift def _opSuffStats(self, x, axes, shift, keep_dims): return tf.nn.sufficient_statistics(x, axes, shift, keep_dims) @@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase): def testSuffStats(self): for has_shape in [True, False]: for keep_dims in [True, False]: - for shift in [True, False]: + for shift in [None, 1.0]: self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape) self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape) self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape) @@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase): self.assertAllClose(npv, tfv, atol=0.000001) def testNormalizeMoments(self): - for shift in [True, False]: + for shift in [None, 4.0]: self._testNormalizeMoments([3], shift) self._testNormalizeMoments([2, 3], shift)