Add weighted_moments, and allow batch norm to use it to compute frequency-weighted statistics.
Change: 134717043
This commit is contained in: tensorflow
parent bc0a56da15
commit ef9f5fee0a
@@ -123,6 +123,7 @@ def batch_norm(inputs,
                variables_collections=None,
                outputs_collections=None,
                trainable=True,
+               batch_weights=None,
                scope=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

@@ -171,6 +172,11 @@ def batch_norm(inputs,
     outputs_collections: collections to add the outputs.
     trainable: If `True` also add variables to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    batch_weights: An optional tensor of shape `[batch_size]`,
+      containing a frequency weight for each batch item. If present,
+      then the batch normalization uses weighted mean and
+      variance. (This can be used to correct for bias in training
+      example selection.)
     scope: Optional scope for `variable_scope`.

   Returns:
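A minimal usage sketch of the new argument (illustrative only: the placeholder shapes and the `tf.contrib.layers` call site are assumptions, not part of this commit):

# Hypothetical example: frequency-weighted batch norm. A weight of 0.5
# counts an item as half an example when computing batch statistics,
# e.g. to compensate for examples that were sampled twice as often.
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 28, 28, 64])
weights = tf.placeholder(tf.float32, [None])  # one weight per batch item

net = tf.contrib.layers.batch_norm(images,
                                   is_training=True,
                                   batch_weights=weights)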
@@ -187,6 +193,14 @@ def batch_norm(inputs,
     if inputs_rank is None:
       raise ValueError('Inputs %s has undefined rank.' % inputs.name)
     dtype = inputs.dtype.base_dtype
+    if batch_weights is not None:
+      batch_weights = ops.convert_to_tensor(batch_weights)
+      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
+
+      # Reshape batch weight values so they broadcast across inputs.
+      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
+      batch_weights = array_ops.reshape(batch_weights, nshape)
+
     axis = list(range(inputs_rank - 1))
     params_shape = inputs_shape[-1:]
     if not params_shape.is_fully_defined():
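To make the broadcast concrete, a small sketch of what the reshape produces for rank-4 NHWC inputs (the rank is assumed for illustration):

# For rank-4 inputs, nshape becomes [-1, 1, 1, 1]: a [batch_size]
# weight vector reshaped this way broadcasts across height, width,
# and channels of a [batch_size, height, width, channels] input.
inputs_rank = 4
nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
assert nshape == [-1, 1, 1, 1]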
@@ -240,9 +254,13 @@ def batch_norm(inputs,
     need_moments = is_training_value is None or is_training_value
     if need_moments:
       # Calculate the moments based on the individual batch.
-      # Use a copy of moving_mean as a shift to compute more reliable moments.
-      shift = math_ops.add(moving_mean, 0)
-      mean, variance = nn.moments(inputs, axis, shift=shift)
+      if batch_weights is None:
+        # Use a copy of moving_mean as a shift to compute more reliable moments.
+        shift = math_ops.add(moving_mean, 0)
+        mean, variance = nn.moments(inputs, axis, shift=shift)
+      else:
+        mean, variance = nn.weighted_moments(inputs, axis, batch_weights)
+
       moving_vars_fn = lambda: (moving_mean, moving_variance)
       if updates_collections is None:
         def _force_updates():
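The "shift" comment above concerns cancellation error when the data sit far from zero; a standalone NumPy sketch of the effect (synthetic numbers, not from the commit):

import numpy as np

x = (1e4 + np.random.randn(10000)).astype(np.float32)
shift = np.float32(1e4)  # stands in for the copy of moving_mean

d = x - shift
var_shifted = np.mean(d * d) - np.mean(d) ** 2  # stable: |d| is O(1)
var_naive = np.mean(x * x) - np.mean(x) ** 2    # float32 cancellation
# var_shifted stays near 1.0; var_naive can come out wildly off or
# even negative, since x * x needs ~8 significant digits.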
@@ -188,6 +188,7 @@ have varying scale, and to aid generalization.
 @@sufficient_statistics
 @@normalize_moments
 @@moments
+@@weighted_moments

 ## Losses

@@ -819,7 +820,7 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):

   Args:
     x: A `Tensor`.
-    axes: array of ints. Axes along which to compute mean and
+    axes: Array of ints. Axes along which to compute mean and
       variance.
     shift: A `Tensor` containing the value by which to shift the data for
       numerical stability, or `None` if no shift is to be performed. A shift
@@ -848,6 +849,82 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
     return (mean, variance)


+def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
+  """Returns the frequency-weighted mean and variance of `x`.
+
+  Args:
+    x: A tensor.
+    axes: 1-d tensor of int32 values; these are the axes along which
+      to compute mean and variance.
+    frequency_weights: A tensor of positive weights which can be
+      broadcast with x.
+    name: Name used to scope the operation.
+    keep_dims: Produce moments with the same dimensionality as the input.
+
+  Returns:
+    Two tensors: `weighted_mean` and `weighted_variance`.
+  """
+  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
+    x = ops.convert_to_tensor(x, name="x")
+    frequency_weights = ops.convert_to_tensor(
+        frequency_weights, name="frequency_weights")
+
+    # Unlike moments(), this just uses a simpler two-pass method.
+
+    # See comment in moments() WRT precision; it applies here too.
+    needs_cast = x.dtype == dtypes.float16
+    if needs_cast:
+      x = math_ops.cast(x, dtypes.float32)
+
+    if frequency_weights.dtype != x.dtype:
+      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
+
+    # Note that we use keep_dims=True for our reductions regardless of the arg;
+    # this is so that the results remain broadcast-compatible with the inputs.
+    weighted_input_sum = math_ops.reduce_sum(frequency_weights * x,
+                                             axes,
+                                             name="weighted_input_sum",
+                                             keep_dims=True)
+
+    # The shape of the weights isn't necessarily the same as x's
+    # shape, just broadcast-compatible with it -- so this expression
+    # performs broadcasting to give a per-item weight, with the same
+    # shape as (frequency_weights * x). This avoids having to reason
+    # through all the broadcast logic to compute a correct
+    # sum_of_weights.
+    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
+
+    sum_of_weights = math_ops.reduce_sum(
+        broadcasted_weights,
+        axes,
+        name="sum_of_weights",
+        keep_dims=True)
+
+    divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
+
+    weighted_mean = math_ops.mul(weighted_input_sum, divisor)
+
+    # Have the weighted mean; now on to variance:
+    weighted_distsq = math_ops.reduce_sum(
+        frequency_weights * math_ops.squared_difference(x, weighted_mean),
+        axes,
+        name="weighted_distsq",
+        keep_dims=True)
+
+    weighted_variance = math_ops.mul(weighted_distsq, divisor)
+
+    if not keep_dims:
+      weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
+      weighted_variance = array_ops.squeeze(weighted_variance,
+                                            squeeze_dims=axes)
+
+    if needs_cast:
+      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
+      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
+
+    return weighted_mean, weighted_variance
+
+
 def batch_normalization(x,
                         mean,
                         variance,
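For reference, the two-pass computation above corresponds to the following NumPy sketch (written for this note, not code from the commit; here `axes` is a tuple of ints):

import numpy as np

def np_weighted_moments(x, axes, frequency_weights, keep_dims=False):
  # Mirrors `frequency_weights + zeros_like(x)`: materialize a
  # per-item weight with the same shape as x.
  w = np.broadcast_to(frequency_weights, x.shape)
  sum_of_weights = w.sum(axis=axes, keepdims=True)
  mean = (w * x).sum(axis=axes, keepdims=True) / sum_of_weights
  # Second pass: weighted average of squared distances from the mean.
  var = (w * (x - mean) ** 2).sum(axis=axes, keepdims=True) / sum_of_weights
  if not keep_dims:
    mean, var = mean.squeeze(axis=axes), var.squeeze(axis=axes)
  return mean, var

With all weights equal, this reduces to the ordinary moments, which is exactly the property the inherited tests below rely on.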
@@ -420,6 +420,16 @@ class NormalizeMomentsTest(tf.test.TestCase):

 class MomentsTest(tf.test.TestCase):

+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    # Method to compute moments of `x` wrt `axes`.
+    #
+    # This is exposed so WeightedMomentsTest can inherit the tests and
+    # assertions from MomentsTest; the extra_out_grads argument allows
+    # its inherited gradient tests to assert gradients against the
+    # weights as well as the input values.
+
+    return tf.nn.moments(x, axes, keep_dims=keep_dims)
+
   def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
     with self.test_session():
       # shape = [batch, width, height, depth]
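The hook pattern used here, reduced to a minimal sketch (class and method names invented for illustration):

class BaseOpTest(tf.test.TestCase):
  # Every test exercises the op only through this one hook...
  def _op(self, x, axes):
    return tf.nn.moments(x, axes)

class OtherOpTest(BaseOpTest):
  # ...so a subclass can swap in a different op and inherit
  # every test method unchanged.
  def _op(self, x, axes):
    return tf.nn.weighted_moments(x, axes, tf.constant(1.0))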
@@ -428,7 +438,7 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.placeholder(dtype, shape=[None] * len(shape))

-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

       num_elements = np.prod([shape[i] for i in axes])
@@ -456,7 +466,11 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.cast(tf.constant(x_numpy), dtype=dtype)

-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      # Compute the expected values at high precision since the method
+      # is prone to catastrophic cancellation:
+      x_numpy = x_numpy.astype(np.float128)
+
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

       num_elements = np.prod([shape[i] for i in axes])
@@ -519,14 +533,21 @@ class MomentsTest(tf.test.TestCase):

       axes = [0, 1, 2]
       y_shape = [2]  # Depth of x
-      out_mean, out_var = tf.nn.moments(x, axes)
+
+      inputs_to_compute_gradients_for = [x]
+
+      out_mean, out_var = self._unweighted_moments(
+          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
       if from_y == "mean":
         y = out_mean
       elif from_y == "var":
         y = out_var
-      err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
-      print("Moments %s gradient err = %g" % (from_y, err))
-      self.assertLess(err, 1e-11)
+
+      for (i, v) in enumerate(inputs_to_compute_gradients_for):
+        err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
+                                             y, y_shape)
+        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
+        self.assertLess(err, 1e-11)

   def testMeanGlobalGradient(self):
     self._testGlobalGradient(from_y="mean")
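For context, tf.test.compute_gradient_error numerically estimates the Jacobian of y with respect to v and returns its maximum absolute deviation from the analytic Jacobian. A toy check in the same spirit (shapes and the float64 choice are illustrative, not from the commit):

with tf.Session():
  v = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float64)
  freq = tf.constant([2.0, 1.0], dtype=tf.float64)
  mean, _ = tf.nn.weighted_moments(v, [0], freq)
  err = tf.test.compute_gradient_error(v, [2, 2], mean, [2])
  # err should be tiny if the registered gradients are correct.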
@@ -534,19 +555,101 @@ class MomentsTest(tf.test.TestCase):
   def testVarGlobalGradient(self):
     self._testGlobalGradient(from_y="var")

   def testOutputNamesNoKeep(self):
     """Make sure the output names are stable."""
     with self.test_session():
       mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
       self.assertEquals(mean.op.name, "moments/normalize/mean")
       self.assertEquals(var.op.name, "moments/normalize/variance")

   def testOutputNamesKeep(self):
     """Make sure the output names are stable."""
     with self.test_session():
       mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
       self.assertEquals(mean.op.name, "moments/normalize/mean")
       self.assertEquals(var.op.name, "moments/normalize/variance")

+
+class WeightedMomentsTest(MomentsTest):
+  """Tests for nn.weighted_moments.
+
+  Note that this test inherits from MomentsTest, inheriting all its
+  test methods!
+
+  It modifies MomentsTest in two ways:
+
+  a) By overriding _unweighted_moments, all the codepaths in
+     MomentsTest are executed, but with calls to tf.nn.moments()
+     replaced by calls to tf.nn.weighted_moments() with a constant
+     weight of 1.
+
+  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
+     this test adds multiple additional calls to
+     RunWeightedMomentTest() to exercise correctness with
+     non-constant weights and varying broadcasting situations. (It
+     also continues to call MomentsTest.Run(Weighted)?MomentTest as
+     well.)
+  """
+
+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    weights = tf.constant(1, dtype=x.dtype)
+    if extra_out_grads is not None:
+      # We want to assert gradients WRT weights as well as X!
+      extra_out_grads.append(weights)
+    return tf.nn.weighted_moments(
+        x, axes, weights, keep_dims=keep_dims)
+
+  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
+    if not dynshapes:
+      super(WeightedMomentsTest, self).RunMomentTest(
+          shape, axes, keep_dims, dtype)
+    else:
+      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
+          shape, axes, keep_dims, dtype)
+
+    # 1:1 weights and inputs
+    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
+
+    # Various broadcasting combinations
+    for idx in range(len(shape)):
+      # try broadcasting weights in all positions
+      weight_shape = [1] * len(shape)
+      weight_shape[idx] = shape[idx]
+
+      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
+
+      # Also try broadcasting with a suffix of length n
+      weight_shape = shape[-(idx + 1):]
+      self.RunWeightedMomentTest(
+          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
+
+  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
+    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
+
+  def RunWeightedMomentTest(
+      self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
+    with self.test_session() as s:
+      x_numpy = np.random.normal(size=shape).astype(np.float32)
+      weights_numpy = np.absolute(  # weights must be positive
+          np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))
+
+      # Expand the numpy version to higher precision
+      x_numpy = x_numpy.astype(np.float128)
+      weights_numpy = weights_numpy.astype(np.float128)
+
+      x_shape = [None] * len(shape) if dynshapes else shape
+      weights_shape = (
+          [None] * len(weights_shape) if dynshapes else weights_shape)
+
+      x = tf.placeholder(dtype, shape=x_shape)
+      weights = tf.placeholder(dtype, shape=weights_shape)
+
+      mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)
+
+      ax = tuple(axes)
+
+      def _np_weighted_sum(v):
+        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
+
+      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
+      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
+      expected_mean_squared = np.multiply(expected_mean, expected_mean)
+      expected_x_squared = (
+          _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
+      expected_variance = expected_x_squared - expected_mean_squared
+
+      mean_v, var_v = s.run([mean, var],
+                            feed_dict={x: x_numpy, weights: weights_numpy})
+
+      self.assertAllCloseAccordingToType(expected_mean, mean_v)
+      self.assertAllCloseAccordingToType(expected_variance, var_v)


 if __name__ == "__main__":
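The expected values in RunWeightedMomentTest use the identity Var_w(x) = E_w[x^2] - (E_w[x])^2, while the op itself computes E_w[(x - E_w[x])^2]; a quick NumPy check that the two agree (illustrative only, mirroring the test's float128 precision):

import numpy as np

x = np.random.randn(1000).astype(np.float128)
w = np.abs(np.random.randn(1000).astype(np.float128)) + 0.1

mean_w = (w * x).sum() / w.sum()
var_two_pass = (w * (x - mean_w) ** 2).sum() / w.sum()    # what the op does
var_identity = (w * x * x).sum() / w.sum() - mean_w ** 2  # what the test expects
assert np.allclose(float(var_two_pass), float(var_identity))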