diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index ecff8241a63..bc5ba95348a 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -134,6 +134,8 @@ have varying scale, and to aid generalization.
 @@l2_normalize
 @@local_response_normalization
+@@sufficient_statistics
+@@aggregate_moments
 @@moments
 
 ## Losses
@@ -495,6 +497,101 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
                         padding="VALID",
                         name=name)
 
 
+def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
+  """Calculate the sufficient statistics for the mean and variance of `x`.
+
+  These sufficient statistics are computed using the one-pass algorithm on
+  an input that's optionally shifted by the value of the first element in `x`.
+  See:
+  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
+
+  Args:
+    x: A `Tensor`.
+    axes: Array of ints. Axes along which to compute mean and variance.
+    shift: If true, shift the data to provide more numerically stable results.
+    keep_dims: Produce statistics with the same dimensionality as the input.
+    name: Name used to scope the operations that compute the sufficient stats.
+
+  Returns:
+    Four `Tensor` objects of the same type as `x`:
+    * the count (number of elements to average over).
+    * the (possibly shifted) sum of the elements in the array.
+    * the (possibly shifted) sum of squares of the elements in the array.
+    * the shift by which the mean must be corrected, or None if `shift` is False.
+  """
+  with ops.op_scope([x, axes], name, "sufficient_statistics"):
+    x = ops.convert_to_tensor(x, name="x")
+    x_shape = x.get_shape()
+    if x_shape.is_fully_defined():
+      counts = 1
+      m_shape = []
+      for d in xrange(x_shape.ndims):
+        dim = x_shape[d].value
+        if d in set(axes):
+          counts *= dim
+          dim = 1
+        m_shape.append(dim)
+      counts = constant_op.constant(counts, dtype=x.dtype)
+    else:  # shape needs to be inferred at runtime.
+      x_shape = array_ops.shape(x)
+      select_axes = sparse_ops.sparse_to_dense(axes, array_ops.shape(x_shape),
+                                               True, False)
+      m_shape = math_ops.select(select_axes, array_ops.ones_like(x_shape),
+                                x_shape)
+      counts = math_ops.cast(
+          math_ops.reduce_prod(x_shape / m_shape),
+          x.dtype,
+          name="count")
+    if shift:
+      shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
+      m_ss = math_ops.sub(x, shift_value)
+      v_ss = math_ops.squared_difference(x, shift_value)
+      if keep_dims:
+        shift_value = array_ops.identity(shift_value, name="shift")
+      else:
+        shift_value = array_ops.squeeze(shift_value,
+                                        squeeze_dims=axes,
+                                        name="shift")
+    else:  # not shift.
+      m_ss = x
+      v_ss = math_ops.square(x)
+      shift_value = None
+    m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
+    v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
+    return counts, m_ss, v_ss, shift_value
+
+
+def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
+  """Calculate the mean and variance based on the sufficient statistics.
+
+  Args:
+    counts: A `Tensor` containing the total count of the data (one value).
+    mean_ss: A `Tensor` containing the mean sufficient statistics: the
+      (possibly shifted) sum of the elements to average over.
+    variance_ss: A `Tensor` containing the variance sufficient statistics: the
+      (possibly shifted) squared sum of the data to compute the variance over.
+    shift: A `Tensor` containing the value by which the data is shifted for
+      numerical stability, or `None` if no shift was performed.
+    name: Name used to scope the operations that compute the moments.
+
+  Returns:
+    Two `Tensor` objects: `mean` and `variance`.
+  """
+  with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
+    divisor = math_ops.inv(counts, name="divisor")
+    if shift is not None:
+      shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
+      mean = math_ops.add(shifted_mean, shift, name="mean")
+    else:  # no shift.
+      shifted_mean = math_ops.mul(mean_ss, divisor, name="mean")
+      mean = shifted_mean
+    variance = math_ops.sub(
+        math_ops.mul(variance_ss, divisor),
+        math_ops.square(shifted_mean),
+        name="variance")
+    return (mean, variance)
+
+
 def moments(x, axes, name=None, keep_dims=False):
   """Calculate the mean and variance of `x`.
@@ -519,40 +616,11 @@ def moments(x, axes, name=None, keep_dims=False):
     Two `Tensor` objects: `mean` and `variance`.
   """
   with ops.op_scope([x, axes], name, "moments"):
-    x = ops.convert_to_tensor(x, name="x")
-    x_shape = x.get_shape()
-    if all(x_shape[d].value is not None for d in axes):
-      # The shape is known in the relevant axes, so we can statically
-      # compute the divisor.
-      divisor = 1.0
-      for d in set(axes):
-        divisor *= x.get_shape()[d].value
-      divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
-    else:
-      divisor = constant_op.constant(1.0, dtype=x.dtype)
-      x_dynamic_shape = array_ops.shape(x)
-      for d in set(axes):
-        divisor *= math_ops.cast(x_dynamic_shape[d], x.dtype)
-      divisor = math_ops.inv(divisor, name="divisor")
-    constant_axes = constant_op.constant(axes, name="axes")
-    # Note: We do not use Mean here because it is very slow on GPU.
-    mean = math_ops.mul(
-        math_ops.reduce_sum(x,
-                            constant_axes,
-                            keep_dims=True),
-        divisor,
-        name="mean")
-    var = math_ops.mul(
-        math_ops.reduce_sum(
-            math_ops.squared_difference(x, mean),
-            constant_axes,
-            keep_dims=keep_dims),
-        divisor,
-        name="variance")
-    if keep_dims:
-      return mean, var
-    else:
-      return array_ops.squeeze(mean, squeeze_dims=axes), var
+    counts, m_ss, v_ss, shift = sufficient_statistics(x,
+                                                      axes,
+                                                      keep_dims=keep_dims,
+                                                      name=name)
+    return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
 
 
 def batch_normalization(x,
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 30c866e6a4c..317a0748309 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -476,7 +476,7 @@ class DropoutTest(tf.test.TestCase):
       _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])
 
 
-class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
+class BatchNormalizationTest(tf.test.TestCase):
 
   def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
@@ -670,8 +670,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
     else:
       all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
     to_check = ["dx", "dm", "dv", "db"]
-    for i, n in enumerate(to_check):
-      print(n)
+    for i, _ in enumerate(to_check):
       self.assertAllClose(
           all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
 
@@ -759,6 +758,117 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
           atol=0.005)
 
 
+class SufficientStatisticsTest(tf.test.TestCase):
+
+  def _npSuffStats(self, x, axes, shift, keep_dims):
+    axis = tuple(axes)
+    if shift:
+      shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
+                       for i in xrange(x.ndim)]]
+      m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
+      v_ss = np.sum(
+          (x - shift_value) * (x - shift_value),
+          axis=axis,
+          keepdims=keep_dims)
+    else:
+      shift_value = None
+      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
+      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
+    count = 1.0
+    for d in xrange(x.ndim):
+      if d in set(axes):
+        count *= x.shape[d]
+    if not keep_dims and shift_value is not None:
+      shift_value = np.squeeze(shift_value, axis=axis)
+    return count, m_ss, v_ss, shift_value
+
+  def _opSuffStats(self, x, axes, shift, keep_dims):
+    return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
+
+  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
+    x_val = np.random.random_sample(x_shape).astype(np.float32)
+    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        if has_shape:
+          x = tf.constant(x_val, name="x")
+          x.set_shape(x_shape)
+          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
+          if shift:
+            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
+          else:
+            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
+        else:
+          x = tf.placeholder(dtype=tf.float32,
+                             shape=[None] * len(x_shape),
+                             name="x")
+          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
+          if shift:
+            tf_c, tf_m, tf_v, tf_s = sess.run(
+                [op_c, op_m, op_v, op_s],
+                feed_dict={x: x_val})
+          else:
+            tf_c, tf_m, tf_v = sess.run(
+                [op_c, op_m, op_v],
+                feed_dict={x: x_val})
+        self.assertAllClose(np_c, tf_c, atol=0.000001)
+        self.assertAllClose(np_m, tf_m, atol=0.000001)
+        self.assertAllClose(np_v, tf_v, atol=0.000001)
+        if shift:
+          self.assertAllClose(np_s, tf_s, atol=0.000001)
+
+  def testSuffStats(self):
+    for has_shape in [True, False]:
+      for keep_dims in [True, False]:
+        for shift in [True, False]:
+          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
+          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
+          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
+
+
+class AggregateMomentsTest(tf.test.TestCase):
+
+  def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
+    mean = mean_ss / counts
+    variance = variance_ss / counts - mean * mean
+    if shift is not None:
+      mean += shift
+    return mean, variance
+
+  def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
+    return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)
+
+  def _testAggregateMoments(self, shape, shift):
+    counts = np.ones([1]).astype(np.float32)
+    mean_ss = np.random.random_sample(shape).astype(np.float32)
+    variance_ss = np.random.random_sample(shape).astype(np.float32)
+    variance_ss *= variance_ss
+    if shift:
+      shift_v = np.random.random_sample(shape).astype(np.float32)
+    else:
+      shift_v = None
+    npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        tf_counts = tf.constant(counts, name="counts")
+        tf_mean_ss = tf.constant(mean_ss, name="mean_ss")
+        tf_variance_ss = tf.constant(variance_ss, name="variance_ss")
+        if shift:
+          tf_shift_v = tf.constant(shift_v, name="shift")
+        else:
+          tf_shift_v = None
+        opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
+                                            tf_variance_ss, tf_shift_v)
+        tfm, tfv = sess.run([opm, opv])
+        self.assertAllClose(npm, tfm, atol=0.000001)
+        self.assertAllClose(npv, tfv, atol=0.000001)
+
+  def testAggregateMoments(self):
+    for shift in [True, False]:
+      self._testAggregateMoments([3], shift)
+      self._testAggregateMoments([2, 3], shift)
+
+
 class MomentsTest(tf.test.TestCase):
 
   def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims):
@@ -857,6 +967,20 @@ class MomentsTest(tf.test.TestCase):
   def testVarGlobalGradient(self):
     self._testGlobalGradient(from_y="var")
 
+  def testOutputNamesNoKeep(self):
+    """Make sure the output names are stable."""
+    with self.test_session():
+      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
+      self.assertEquals(mean.op.name, "moments/aggregate/mean")
+      self.assertEquals(var.op.name, "moments/aggregate/variance")
+
+  def testOutputNamesKeep(self):
+    """Make sure the output names are stable."""
+    with self.test_session():
+      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
+      self.assertEquals(mean.op.name, "moments/aggregate/mean")
+      self.assertEquals(var.op.name, "moments/aggregate/variance")
+
 
 class ComputeSampledLogitsTest(tf.test.TestCase):
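
For review context, here is a minimal NumPy sketch (not part of the patch) of the shifted one-pass computation that `sufficient_statistics` and `aggregate_moments` implement, following the Wikipedia reference cited in the docstring. The helper names `np_sufficient_statistics` and `np_aggregate_moments` are illustrative only, not proposed API.

import numpy as np

def np_sufficient_statistics(x, axes, shift=True):
    """Count, (optionally shifted) sum, and sum of squares over `axes`."""
    axes = tuple(axes)
    if shift:
        # Shift by the first element along each reduced axis, mirroring
        # the array_ops.slice call in the op.
        idx = tuple(slice(0, 1) if d in axes else slice(None)
                    for d in range(x.ndim))
        s = x[idx]
        shifted = x - s
    else:
        s = None
        shifted = x
    count = 1.0
    for d in axes:
        count *= x.shape[d]
    m_ss = shifted.sum(axis=axes, keepdims=True)
    v_ss = (shifted ** 2).sum(axis=axes, keepdims=True)
    return count, m_ss, v_ss, s

def np_aggregate_moments(count, m_ss, v_ss, s):
    """Recover mean and variance from the sufficient statistics."""
    shifted_mean = m_ss / count
    mean = shifted_mean if s is None else shifted_mean + s
    # E[(x - s)^2] - (E[x - s])^2 == Var[x]: the shift cancels, so the
    # variance needs no correction.
    variance = v_ss / count - shifted_mean ** 2
    return mean, variance

# A large offset exercises the numerical-stability motivation for shifting.
x = np.random.randn(4, 5) + 1e6
c, m, v, s = np_sufficient_statistics(x, axes=[0])
mean, var = np_aggregate_moments(c, m, v, s)
assert np.allclose(mean, x.mean(axis=0, keepdims=True))
assert np.allclose(var, x.var(axis=0, keepdims=True))

One property of this factoring worth noting: with `shift=False`, the statistics from several batches can simply be added elementwise (and the counts summed) before a single aggregation step, which a fused moments op cannot offer.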