diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0ed396e4532..dc4ee9226a4 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -123,6 +123,7 @@ def batch_norm(inputs,
                variables_collections=None,
                outputs_collections=None,
                trainable=True,
+               batch_weights=None,
                scope=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

@@ -171,6 +172,11 @@ def batch_norm(inputs,
     outputs_collections: collections to add the outputs.
     trainable: If `True` also add variables to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    batch_weights: An optional tensor of shape `[batch_size]`,
+      containing a frequency weight for each batch item. If present,
+      then the batch normalization uses weighted mean and
+      variance. (This can be used to correct for bias in training
+      example selection.)
     scope: Optional scope for `variable_scope`.

   Returns:
@@ -187,6 +193,14 @@ def batch_norm(inputs,
     if inputs_rank is None:
       raise ValueError('Inputs %s has undefined rank.' % inputs.name)
     dtype = inputs.dtype.base_dtype
+    if batch_weights is not None:
+      batch_weights = ops.convert_to_tensor(batch_weights)
+      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
+
+      # Reshape batch weight values so they broadcast across inputs.
+      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
+      batch_weights = array_ops.reshape(batch_weights, nshape)
+
     axis = list(range(inputs_rank - 1))
     params_shape = inputs_shape[-1:]
     if not params_shape.is_fully_defined():
@@ -240,9 +254,13 @@ def batch_norm(inputs,
     need_moments = is_training_value is None or is_training_value
     if need_moments:
       # Calculate the moments based on the individual batch.
-      # Use a copy of moving_mean as a shift to compute more reliable moments.
-      shift = math_ops.add(moving_mean, 0)
-      mean, variance = nn.moments(inputs, axis, shift=shift)
+      if batch_weights is None:
+        # Use a copy of moving_mean as a shift to compute more reliable moments.
+        shift = math_ops.add(moving_mean, 0)
+        mean, variance = nn.moments(inputs, axis, shift=shift)
+      else:
+        mean, variance = nn.weighted_moments(inputs, axis, batch_weights)
+
       moving_vars_fn = lambda: (moving_mean, moving_variance)
       if updates_collections is None:
         def _force_updates():
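A minimal usage sketch of the new argument (illustrative only, not part of the patch; the placeholder names are hypothetical): a per-example weight vector of shape [batch_size] is passed through `batch_weights`, so the layer computes weighted batch statistics while training.

    import tensorflow as tf

    images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    # One frequency weight per batch item, shape [batch_size].
    example_weights = tf.placeholder(tf.float32, shape=[None])

    net = tf.contrib.layers.batch_norm(
        images, is_training=True, batch_weights=example_weights)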
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 958c32f0fc8..992e0f6f791 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -188,6 +188,7 @@ have varying scale, and to aid generalization.
 @@sufficient_statistics
 @@normalize_moments
 @@moments
+@@weighted_moments

 ## Losses

@@ -819,7 +820,7 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):

   Args:
     x: A `Tensor`.
-    axes: array of ints. Axes along which to compute mean and
+    axes: Array of ints. Axes along which to compute mean and
       variance.
     shift: A `Tensor` containing the value by which to shift the data for
       numerical stability, or `None` if no shift is to be performed. A shift
@@ -848,6 +849,82 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
     return (mean, variance)


+def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
+  """Returns the frequency-weighted mean and variance of `x`.
+
+  Args:
+    x: A tensor.
+    axes: 1-d tensor of int32 values; these are the axes along which
+      to compute mean and variance.
+    frequency_weights: A tensor of positive weights which can be
+      broadcast with `x`.
+    name: Name used to scope the operation.
+    keep_dims: Produce moments with the same dimensionality as the input.
+
+  Returns:
+    Two tensors: `weighted_mean` and `weighted_variance`.
+  """
+  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
+    x = ops.convert_to_tensor(x, name="x")
+    frequency_weights = ops.convert_to_tensor(
+        frequency_weights, name="frequency_weights")
+
+    # Unlike moments(), this just uses a simpler two-pass method.
+
+    # See comment in moments() WRT precision; it applies here too.
+    needs_cast = x.dtype == dtypes.float16
+    if needs_cast:
+      x = math_ops.cast(x, dtypes.float32)
+
+    if frequency_weights.dtype != x.dtype:
+      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
+
+    # Note that we use keep_dims=True for our reductions regardless of the arg;
+    # this is so that the results remain broadcast-compatible with the inputs.
+    weighted_input_sum = math_ops.reduce_sum(frequency_weights * x,
+                                             axes,
+                                             name="weighted_input_sum",
+                                             keep_dims=True)
+
+    # The shape of the weights isn't necessarily the same as x's
+    # shape, just broadcast-compatible with it -- so this expression
+    # performs broadcasting to give a per-item weight, with the same
+    # shape as (frequency_weights * x). This avoids having to reason
+    # through all the broadcast logic to compute a correct
+    # sum_of_weights.
+    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
+
+    sum_of_weights = math_ops.reduce_sum(
+        broadcasted_weights,
+        axes,
+        name="sum_of_weights",
+        keep_dims=True)
+
+    divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
+
+    weighted_mean = math_ops.mul(weighted_input_sum, divisor)
+
+    # Have the weighted mean; now on to variance:
+    weighted_distsq = math_ops.reduce_sum(
+        frequency_weights * math_ops.squared_difference(x, weighted_mean),
+        axes,
+        name="weighted_distsq",
+        keep_dims=True)
+
+    weighted_variance = math_ops.mul(weighted_distsq, divisor)
+
+    if not keep_dims:
+      weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
+      weighted_variance = array_ops.squeeze(weighted_variance,
+                                            squeeze_dims=axes)
+
+    if needs_cast:
+      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
+      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
+
+    return weighted_mean, weighted_variance
+
+
 def batch_normalization(x,
                         mean,
                         variance,
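For reference, a small NumPy sketch (illustrative only, not part of the patch) of the two-pass computation above: sum the broadcast weights, form the weighted mean, then the weighted mean of squared differences. Array names and shapes here are hypothetical.

    import numpy as np

    x = np.random.normal(size=(8, 3))
    w = np.abs(np.random.normal(loc=1.0, size=(8, 1)))  # positive, broadcastable with x

    sum_w = np.sum(w * np.ones_like(x), axis=0, keepdims=True)
    weighted_mean = np.sum(w * x, axis=0, keepdims=True) / sum_w
    weighted_variance = (
        np.sum(w * (x - weighted_mean) ** 2, axis=0, keepdims=True) / sum_w)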
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index 9ccf331c484..5e928fba569 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -420,6 +420,16 @@ class NormalizeMomentsTest(tf.test.TestCase):

 class MomentsTest(tf.test.TestCase):

+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    # Method to compute moments of `x` wrt `axes`.
+    #
+    # This is exposed so WeightedMomentsTest can inherit the tests and
+    # assertions from MomentsTest; the extra_out_grads argument allows
+    # its inherited gradient tests to assert gradients against the
+    # weights as well as the input values.
+
+    return tf.nn.moments(x, axes, keep_dims=keep_dims)
+
   def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
     with self.test_session():
       # shape = [batch, width, height, depth]
@@ -428,7 +438,7 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.placeholder(dtype, shape=[None] * len(shape))

-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

       num_elements = np.prod([shape[i] for i in axes])

@@ -456,7 +466,11 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.cast(tf.constant(x_numpy), dtype=dtype)

-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      # Compute the expected values at high precision since the method
+      # is prone to catastrophic cancellation:
+      x_numpy = x_numpy.astype(np.float128)
+
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

       num_elements = np.prod([shape[i] for i in axes])

@@ -519,14 +533,21 @@ class MomentsTest(tf.test.TestCase):

       axes = [0, 1, 2]
       y_shape = [2]  # Depth of x
-      out_mean, out_var = tf.nn.moments(x, axes)
+
+      inputs_to_compute_gradients_for = [x]
+
+      out_mean, out_var = self._unweighted_moments(
+          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
       if from_y == "mean":
         y = out_mean
       elif from_y == "var":
         y = out_var
-      err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
-      print("Moments %s gradient err = %g" % (from_y, err))
-      self.assertLess(err, 1e-11)
+
+      for (i, v) in enumerate(inputs_to_compute_gradients_for):
+        err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
+                                             y, y_shape)
+        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
+        self.assertLess(err, 1e-11)

   def testMeanGlobalGradient(self):
     self._testGlobalGradient(from_y="mean")
@@ -534,19 +555,101 @@ class MomentsTest(tf.test.TestCase):
   def testVarGlobalGradient(self):
     self._testGlobalGradient(from_y="var")

-  def testOutputNamesNoKeep(self):
-    """Make sure the output names are stable."""
-    with self.test_session():
-      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
-      self.assertEquals(mean.op.name, "moments/normalize/mean")
-      self.assertEquals(var.op.name, "moments/normalize/variance")
-
-  def testOutputNamesKeep(self):
-    """Make sure the output names are stable."""
-    with self.test_session():
-      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
-      self.assertEquals(mean.op.name, "moments/normalize/mean")
-      self.assertEquals(var.op.name, "moments/normalize/variance")
+
+class WeightedMomentsTest(MomentsTest):
+  """Tests for nn.weighted_moments.
+
+  Note that this test class inherits from MomentsTest, so it runs all
+  of MomentsTest's test methods as well.
+
+  It modifies MomentsTest in two ways:
+
+  a) By overriding _unweighted_moments, all the codepaths in
+     MomentsTest are executed, but with calls to tf.nn.moments()
+     replaced by calls to tf.nn.weighted_moments() with a constant
+     weight of 1.
+
+  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
+     this test adds multiple additional calls to
+     RunWeightedMomentTest() to exercise correctness with
+     non-constant weights and varying broadcasting situations. (It
+     also continues to call the inherited MomentsTest.RunMomentTest
+     methods.)
+  """
+
+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    weights = tf.constant(1, dtype=x.dtype)
+    if extra_out_grads is not None:
+      # We want to assert gradients WRT weights as well as X!
+      extra_out_grads.append(weights)
+    return tf.nn.weighted_moments(
+        x, axes, weights, keep_dims=keep_dims)
+
+  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
+    if not dynshapes:
+      super(WeightedMomentsTest, self).RunMomentTest(
+          shape, axes, keep_dims, dtype)
+    else:
+      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
+          shape, axes, keep_dims, dtype)
+
+    # 1:1 weights and inputs
+    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
+
+    # Various broadcasting combinations
+    for idx in range(len(shape)):
+      # try broadcasting weights in all positions
+      weight_shape = [1] * len(shape)
+      weight_shape[idx] = shape[idx]
+
+      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
+
+      # Also try broadcasting with a suffix of length n
+      weight_shape = shape[-(idx+1):]
+      self.RunWeightedMomentTest(
+          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
+
+  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
+    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
+
+  def RunWeightedMomentTest(
+      self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
+    with self.test_session() as s:
+      x_numpy = np.random.normal(size=shape).astype(np.float32)
+      weights_numpy = np.absolute(  # weights must be positive
+          np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))
+
+      # Expand the numpy version to higher precision
+      x_numpy = x_numpy.astype(np.float128)
+      weights_numpy = weights_numpy.astype(np.float128)
+
+      x_shape = [None] * len(shape) if dynshapes else shape
+      weights_shape = (
+          [None] * len(weights_shape) if dynshapes else weights_shape)
+
+      x = tf.placeholder(dtype, shape=x_shape)
+      weights = tf.placeholder(dtype, shape=weights_shape)
+
+      mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)
+
+      ax = tuple(axes)
+
+      def _np_weighted_sum(v):
+        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
+
+      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
+      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
+      expected_mean_squared = np.multiply(expected_mean, expected_mean)
+      expected_x_squared = (
+          _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
+      expected_variance = expected_x_squared - expected_mean_squared
+
+      mean_v, var_v = s.run([mean, var],
+                            feed_dict={x: x_numpy, weights: weights_numpy})
+
+      self.assertAllCloseAccordingToType(expected_mean, mean_v)
+      self.assertAllCloseAccordingToType(expected_variance, var_v)


 if __name__ == "__main__":
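An illustrative call to the new op (not part of the patch), using the graph-and-session style of the tests above; the tensor names and shapes are hypothetical.

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.normal(size=(8, 16)).astype(np.float32))
    # Positive weights, broadcast across the feature axis.
    w = tf.constant(
        np.abs(np.random.normal(loc=1.0, size=(8, 1))).astype(np.float32))

    mean, variance = tf.nn.weighted_moments(x, axes=[0], frequency_weights=w)

    with tf.Session() as sess:
      print(sess.run([mean, variance]))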