Add weighted_moments, and allow batch norm to use it to compute frequency-weighted statistics.

Change: 134717043
A. Unique TensorFlower 2016-09-29 14:27:41 -08:00 committed by TensorFlower Gardener
parent bc0a56da15
commit ef9f5fee0a
3 changed files with 220 additions and 22 deletions
Changed paths: tensorflow/contrib/layers/python/layers, tensorflow/python/ops


@@ -123,6 +123,7 @@ def batch_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
batch_weights=None,
scope=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -171,6 +172,11 @@ def batch_norm(inputs,
outputs_collections: collections to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
batch_weights: An optional tensor of shape `[batch_size]`,
containing a frequency weight for each batch item. If present,
then the batch normalization uses weighted mean and
variance. (This can be used to correct for bias in training
example selection.)
scope: Optional scope for `variable_scope`.
Returns:
@@ -187,6 +193,14 @@ def batch_norm(inputs,
if inputs_rank is None:
raise ValueError('Inputs %s has undefined rank.' % inputs.name)
dtype = inputs.dtype.base_dtype
if batch_weights is not None:
batch_weights = ops.convert_to_tensor(batch_weights)
inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
# Reshape batch weight values so they broadcast across inputs.
nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
batch_weights = array_ops.reshape(batch_weights, nshape)
axis = list(range(inputs_rank - 1))
params_shape = inputs_shape[-1:]
if not params_shape.is_fully_defined():
@@ -240,9 +254,13 @@ def batch_norm(inputs,
need_moments = is_training_value is None or is_training_value
if need_moments:
# Calculate the moments based on the individual batch.
# Use a copy of moving_mean as a shift to compute more reliable moments.
shift = math_ops.add(moving_mean, 0)
mean, variance = nn.moments(inputs, axis, shift=shift)
if batch_weights is None:
# Use a copy of moving_mean as a shift to compute more reliable moments.
shift = math_ops.add(moving_mean, 0)
mean, variance = nn.moments(inputs, axis, shift=shift)
else:
mean, variance = nn.weighted_moments(inputs, axis, batch_weights)
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
def _force_updates():
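Note: below is a minimal usage sketch of the new batch_weights argument; the input sizes, weight values, and session setup are illustrative assumptions, not part of this change. When batch_weights is supplied and the layer is computing batch statistics (training mode), those statistics come from nn.weighted_moments rather than nn.moments.

import numpy as np
import tensorflow as tf

# Hypothetical batch of 4 examples with 3 features, plus one positive
# frequency weight per example (e.g. an inverse sampling probability
# used to correct for biased example selection).
inputs = tf.placeholder(tf.float32, shape=[4, 3])
weights = tf.placeholder(tf.float32, shape=[4])

outputs = tf.contrib.layers.batch_norm(
    inputs, is_training=True, batch_weights=weights)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(outputs, feed_dict={
        inputs: np.random.randn(4, 3).astype(np.float32),
        weights: np.array([1.0, 1.0, 2.0, 4.0], dtype=np.float32),
    })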


@@ -188,6 +188,7 @@ have varying scale, and to aid generalization.
@@sufficient_statistics
@@normalize_moments
@@moments
@@weighted_moments
## Losses
@@ -819,7 +820,7 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
Args:
x: A `Tensor`.
axes: array of ints. Axes along which to compute mean and
axes: Array of ints. Axes along which to compute mean and
variance.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
@@ -848,6 +849,82 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
return (mean, variance)
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
"""Returns the frequency-weighted mean and variance of `x`.
Args:
x: A tensor.
axes: 1-d tensor of int32 values; these are the axes along which
to compute mean and variance.
frequency_weights: A tensor of positive weights which can be
broadcast with x.
name: Name used to scope the operation.
keep_dims: Produce moments with the same dimensionality as the input.
Returns:
Two tensors: `weighted_mean` and `weighted_variance`.
"""
with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
x = ops.convert_to_tensor(x, name="x")
frequency_weights = ops.convert_to_tensor(
frequency_weights, name="frequency_weights")
# Unlike moments(), this just uses a simpler two-pass method.
# See comment in moments() WRT precision; it applies here too.
needs_cast = x.dtype == dtypes.float16
if needs_cast:
x = math_ops.cast(x, dtypes.float32)
if frequency_weights.dtype != x.dtype:
frequency_weights = math_ops.cast(frequency_weights, x.dtype)
# Note that we use keep_dims=True for our reductions regardless of the arg;
# this is so that the results remain broadcast-compatible with the inputs.
weighted_input_sum = math_ops.reduce_sum(frequency_weights * x,
axes,
name="weighted_input_sum",
keep_dims=True)
# The shape of the weights isn't necessarily the same as x's
# shape, just broadcast-compatible with it -- so this expression
# performs broadcasting to give a per-item weight, with the same
# shape as (frequency_weights * x). This avoids having to reason
# through all the broadcast logic to compute a correct
# sum_of_weights.
broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
sum_of_weights = math_ops.reduce_sum(
broadcasted_weights,
axes,
name="sum_of_weights",
keep_dims=True)
divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
weighted_mean = math_ops.mul(weighted_input_sum, divisor)
# Have the weighted mean; now on to variance:
weighted_distsq = math_ops.reduce_sum(
frequency_weights * math_ops.squared_difference(x, weighted_mean),
axes,
name="weighted_distsq",
keep_dims=True)
weighted_variance = math_ops.mul(weighted_distsq, divisor)
if not keep_dims:
weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
weighted_variance = array_ops.squeeze(weighted_variance,
squeeze_dims=axes)
if needs_cast:
weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
return weighted_mean, weighted_variance
def batch_normalization(x,
mean,
variance,
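As a quick illustration of the new op (a sketch; the tensors, weights, and expected values are assumptions for the example, not taken from the change): weighting the second row three times as heavily as the first pulls the per-column statistics toward it, whereas a constant weight of 1 would reproduce tf.nn.moments.

import tensorflow as tf

x = tf.constant([[1.0, 2.0],
                 [3.0, 4.0]])
# Per-row frequency weights, broadcast across the feature dimension.
w = tf.constant([[1.0],
                 [3.0]])

w_mean, w_var = tf.nn.weighted_moments(x, axes=[0], frequency_weights=w)
u_mean, u_var = tf.nn.moments(x, [0])

with tf.Session() as sess:
    print(sess.run([w_mean, w_var]))  # ~[2.5, 3.5], ~[0.75, 0.75]
    print(sess.run([u_mean, u_var]))  # [2.0, 3.0], [1.0, 1.0]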


@@ -420,6 +420,16 @@ class NormalizeMomentsTest(tf.test.TestCase):
class MomentsTest(tf.test.TestCase):
def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
# Method to compute moments of `x` wrt `axes`.
#
# This is exposed so WeightedMomentsTest can inherit the tests and
# assertions from MomentsTest; the extra_out_grads argument allows
# its inherited gradient tests to assert gradients against the
# weights as well as the input values.
return tf.nn.moments(x, axes, keep_dims=keep_dims)
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
with self.test_session():
# shape = [batch, width, height, depth]
@@ -428,7 +438,7 @@ class MomentsTest(tf.test.TestCase):
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = tf.placeholder(dtype, shape=[None] * len(shape))
mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
@@ -456,7 +466,11 @@ class MomentsTest(tf.test.TestCase):
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = tf.cast(tf.constant(x_numpy), dtype=dtype)
mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
# Compute the expected values at high precision since the method
# is prone to catastrophic cancellation:
x_numpy = x_numpy.astype(np.float128)
mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
@@ -519,14 +533,21 @@ class MomentsTest(tf.test.TestCase):
axes = [0, 1, 2]
y_shape = [2] # Depth of x
out_mean, out_var = tf.nn.moments(x, axes)
inputs_to_compute_gradients_for = [x]
out_mean, out_var = self._unweighted_moments(
x, axes, extra_out_grads=inputs_to_compute_gradients_for)
if from_y == "mean":
y = out_mean
elif from_y == "var":
y = out_var
err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
print("Moments %s gradient err = %g" % (from_y, err))
self.assertLess(err, 1e-11)
for (i, v) in enumerate(inputs_to_compute_gradients_for):
err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
y, y_shape)
print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
self.assertLess(err, 1e-11)
def testMeanGlobalGradient(self):
self._testGlobalGradient(from_y="mean")
@@ -534,19 +555,101 @@ class MomentsTest(tf.test.TestCase):
def testVarGlobalGradient(self):
self._testGlobalGradient(from_y="var")
def testOutputNamesNoKeep(self):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
self.assertEquals(mean.op.name, "moments/normalize/mean")
self.assertEquals(var.op.name, "moments/normalize/variance")
def testOutputNamesKeep(self):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
self.assertEquals(mean.op.name, "moments/normalize/mean")
self.assertEquals(var.op.name, "moments/normalize/variance")
class WeightedMomentsTest(MomentsTest):
"""Tests for nn.weighted_moments.
Note that this test inherits from MomentsTest, inheriting all its
test methods!
It modifies MomentsTest in two ways:
a) By overriding _unweighted_moments, all the codepaths in
MomentsTest are executed, but with calls to tf.nn.moments()
replaced by calls to tf.nn.weighted_moments() with a constant
weight of 1.
b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
this test adds multiple additional calls to
RunWeightedMomentsTest() to exercise correctness with
non-constant weights and varying broadcasting situations. (It
also continues to call MomentsTest.Run(Weighted)?MomentsTest as
well.)
"""
def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
weights = tf.constant(1, dtype=x.dtype)
if extra_out_grads is not None:
# We want to assert gradients WRT weights as well as X!
extra_out_grads.append(weights)
return tf.nn.weighted_moments(
x, axes, weights, keep_dims=keep_dims)
def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
if not dynshapes:
super(WeightedMomentsTest, self).RunMomentTest(
shape, axes, keep_dims, dtype)
else:
super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
shape, axes, keep_dims, dtype)
# 1:1 weights and inputs
self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
# Various broadcasting combinations
for idx in range(len(shape)):
# try broadcasting weights in all positions
weight_shape = [1] * len(shape)
weight_shape[idx] = shape[idx]
self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
# Also try broadcasting with a suffix of length n
weight_shape = shape[-(idx+1):]
self.RunWeightedMomentTest(
shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
def RunWeightedMomentTest(
self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
with self.test_session() as s:
x_numpy = np.random.normal(size=shape).astype(np.float32)
weights_numpy = np.absolute( # weights must be positive
np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))
# Expand the numpy version to higher precision
x_numpy = x_numpy.astype(np.float128)
weights_numpy = weights_numpy.astype(np.float128)
x_shape = [None] * len(shape) if dynshapes else shape
weights_shape = (
[None] * len(weights_shape) if dynshapes else weights_shape)
x = tf.placeholder(dtype, shape=x_shape)
weights = tf.placeholder(dtype, shape=weights_shape)
mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)
ax = tuple(axes)
def _np_weighted_sum(v):
return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
expected_mean = _np_weighted_sum(x_numpy) / weight_sum
expected_mean_squared = np.multiply(expected_mean, expected_mean)
expected_x_squared = (
_np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
expected_variance = expected_x_squared - expected_mean_squared
mean_v, var_v = s.run([mean, var],
feed_dict={x: x_numpy, weights: weights_numpy})
self.assertAllCloseAccordingToType(expected_mean, mean_v)
self.assertAllCloseAccordingToType(expected_variance, var_v)
if __name__ == "__main__":
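A closing note on the expected values computed in RunWeightedMomentTest: the test uses the identity Var_w(x) = E_w[x^2] - (E_w[x])^2, while the op itself takes the weighted mean of squared deviations. A small NumPy sketch (with assumed example values) confirming that the two formulations agree:

import numpy as np

x = np.array([1.0, 2.0, 4.0])
w = np.array([1.0, 2.0, 3.0])  # positive frequency weights

weighted_mean = np.sum(w * x) / np.sum(w)

# Formulation used by weighted_moments: weighted mean of squared deviations.
var_two_pass = np.sum(w * (x - weighted_mean) ** 2) / np.sum(w)

# Formulation used by the test: E_w[x^2] - (E_w[x])^2.
var_identity = np.sum(w * x * x) / np.sum(w) - weighted_mean ** 2

print(var_two_pass, var_identity)  # both ~1.4722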