diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0ed396e4532..dc4ee9226a4 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -123,6 +123,7 @@ def batch_norm(inputs,
                variables_collections=None,
                outputs_collections=None,
                trainable=True,
+               batch_weights=None,
                scope=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
 
@@ -171,6 +172,11 @@ def batch_norm(inputs,
     outputs_collections: collections to add the outputs.
     trainable: If `True` also add variables to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+    batch_weights: An optional tensor of shape `[batch_size]`,
+      containing a frequency weight for each batch item. If present,
+      batch normalization uses the weighted mean and variance of the
+      batch. (This can be used to correct for bias in training
+      example selection.)
     scope: Optional scope for `variable_scope`.
 
   Returns:
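
For context, a sketch of how the new argument might be used, hedged: the placeholder names below are invented for this illustration, while `batch_norm`, `is_training`, and the `batch_weights` argument added above are the contrib API this patch targets.

    # Hypothetical call site for the new batch_weights argument; the
    # tensors `images` and `inverse_frequency` are invented for this sketch.
    images = tf.placeholder(tf.float32, [None, 28, 28, 3])
    inverse_frequency = tf.placeholder(tf.float32, [None])  # one weight per batch item
    net = tf.contrib.layers.batch_norm(images,
                                       is_training=True,
                                       batch_weights=inverse_frequency)
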
@@ -187,6 +193,14 @@ def batch_norm(inputs,
     if inputs_rank is None:
       raise ValueError('Inputs %s has undefined rank.' % inputs.name)
     dtype = inputs.dtype.base_dtype
+    if batch_weights is not None:
+      batch_weights = ops.convert_to_tensor(batch_weights)
+      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
+
+      # Reshape batch weight values so they broadcast across inputs.
+      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
+      batch_weights = array_ops.reshape(batch_weights, nshape)
+
     axis = list(range(inputs_rank - 1))
     params_shape = inputs_shape[-1:]
     if not params_shape.is_fully_defined():
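
To make the reshape above concrete (a sketch, not part of the patch): for a rank-4 `NHWC` input, the `[batch_size]` weight vector is reshaped to `[-1, 1, 1, 1]` so that it broadcasts across the spatial and channel axes.

    inputs_rank = 4  # e.g. [batch, height, width, channels]
    nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
    assert nshape == [-1, 1, 1, 1]  # broadcasts across all non-batch axes
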
@@ -240,9 +254,13 @@ def batch_norm(inputs,
     need_moments = is_training_value is None or is_training_value
     if need_moments:
       # Calculate the moments based on the individual batch.
-      # Use a copy of moving_mean as a shift to compute more reliable moments.
-      shift = math_ops.add(moving_mean, 0)
-      mean, variance = nn.moments(inputs, axis, shift=shift)
+      if batch_weights is None:
+        # Use a copy of moving_mean as a shift to compute more reliable moments.
+        shift = math_ops.add(moving_mean, 0)
+        mean, variance = nn.moments(inputs, axis, shift=shift)
+      else:
+        mean, variance = nn.weighted_moments(inputs, axis, batch_weights)
+
       moving_vars_fn = lambda: (moving_mean, moving_variance)
       if updates_collections is None:
         def _force_updates():
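
As a sanity check on the frequency-weight semantics the weighted branch above relies on (a NumPy sketch, not part of the patch): an item with integer weight k contributes exactly as if it appeared k times in the batch.

    import numpy as np

    x = np.array([1.0, 2.0, 3.0])
    w = np.array([1.0, 1.0, 2.0])      # the item 3.0 counts twice

    wmean = np.sum(w * x) / np.sum(w)  # == np.mean([1., 2., 3., 3.]) == 2.25
    wvar = np.sum(w * (x - wmean) ** 2) / np.sum(w)
    # == np.var([1., 2., 3., 3.]) == 0.6875
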
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 958c32f0fc8..992e0f6f791 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -188,6 +188,7 @@ have varying scale, and to aid generalization.
 @@sufficient_statistics
 @@normalize_moments
 @@moments
+@@weighted_moments
 
 ## Losses
 
@@ -819,7 +820,7 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
 
   Args:
     x: A `Tensor`.
-    axes: array of ints.  Axes along which to compute mean and
+    axes: Array of ints.  Axes along which to compute mean and
       variance.
     shift: A `Tensor` containing the value by which to shift the data for
       numerical stability, or `None` if no shift is to be performed. A shift
@@ -848,6 +849,82 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
         return (mean, variance)
 
 
+def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
+  """Returns the frequency-weighted mean and variance of `x`.
+
+  Args:
+    x: A tensor.
+    axes: 1-d tensor of int32 values; these are the axes along which
+      to compute mean and variance.
+    frequency_weights: A tensor of positive weights that can be
+      broadcast with `x`.
+    name: Name used to scope the operation.
+    keep_dims: Produce moments with the same dimensionality as the input.
+
+  Returns:
+    Two tensors: `weighted_mean` and `weighted_variance`.
+  """
+  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
+    x = ops.convert_to_tensor(x, name="x")
+    frequency_weights = ops.convert_to_tensor(
+        frequency_weights, name="frequency_weights")
+
+    # Unlike moments(), this uses a straightforward two-pass method.
+
+    # The comment in moments() about float16 precision applies here too.
+    needs_cast = x.dtype == dtypes.float16
+    if needs_cast:
+      x = math_ops.cast(x, dtypes.float32)
+
+    if frequency_weights.dtype != x.dtype:
+      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
+
+    # We use keep_dims=True for the reductions regardless of the keep_dims
+    # argument, so that the results remain broadcast-compatible with x.
+    weighted_input_sum = math_ops.reduce_sum(frequency_weights * x,
+                                             axes,
+                                             name="weighted_input_sum",
+                                             keep_dims=True)
+
+    # The shape of the weights isn't necessarily the same as x's
+    # shape; it is merely broadcast-compatible with it. This expression
+    # therefore performs broadcasting to give a per-item weight, with
+    # the same shape as (frequency_weights * x). This avoids having to
+    # reason through all of the broadcast logic to compute a correct
+    # sum_of_weights.
+    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
+
+    sum_of_weights = math_ops.reduce_sum(
+        broadcasted_weights,
+        axes,
+        name="sum_of_weights",
+        keep_dims=True)
+
+    divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
+
+    weighted_mean = math_ops.mul(weighted_input_sum, divisor)
+
+    # With the weighted mean in hand, compute the weighted variance:
+    weighted_distsq = math_ops.reduce_sum(
+        frequency_weights * math_ops.squared_difference(x, weighted_mean),
+        axes,
+        name="weighted_distsq",
+        keep_dims=True)
+
+    weighted_variance = math_ops.mul(weighted_distsq, divisor)
+
+    if not keep_dims:
+      weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
+      weighted_variance = array_ops.squeeze(weighted_variance,
+                                            squeeze_dims=axes)
+
+    if needs_cast:
+      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
+      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
+
+    return weighted_mean, weighted_variance
+
+
 def batch_normalization(x,
                         mean,
                         variance,
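
A minimal NumPy restatement of the two-pass computation above, for reference (a sketch only, not part of the patch; `np.broadcast_to` stands in for the `frequency_weights + zeros_like(x)` trick, and the float16 cast path is omitted):

    import numpy as np

    def weighted_moments_ref(x, axes, w, keep_dims=False):
        # Reference sketch of the two-pass weighted moments above.
        axes = tuple(axes)
        w = np.broadcast_to(w, x.shape)  # mirrors w + zeros_like(x)
        sum_w = np.sum(w, axis=axes, keepdims=True)
        mean = np.sum(w * x, axis=axes, keepdims=True) / sum_w
        var = np.sum(w * (x - mean) ** 2, axis=axes, keepdims=True) / sum_w
        if not keep_dims:
            mean, var = np.squeeze(mean, axes), np.squeeze(var, axes)
        return mean, var

With all weights equal, this reduces to the ordinary mean and variance, which is what lets the tests below reuse the unweighted assertions.
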
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index 9ccf331c484..5e928fba569 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -420,6 +420,16 @@ class NormalizeMomentsTest(tf.test.TestCase):
 
 class MomentsTest(tf.test.TestCase):
 
+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    # Computes the moments of `x` over `axes`.
+    #
+    # This is exposed so WeightedMomentsTest can inherit the tests and
+    # assertions from MomentsTest; the extra_out_grads argument allows
+    # its inherited gradient tests to assert gradients against the
+    # weights as well as the input values.
+
+    return tf.nn.moments(x, axes, keep_dims=keep_dims)
+
   def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
     with self.test_session():
       # shape = [batch, width, height, depth]
@@ -428,7 +438,7 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.placeholder(dtype, shape=[None] * len(shape))
 
-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
 
       num_elements = np.prod([shape[i] for i in axes])
 
@@ -456,7 +466,11 @@ class MomentsTest(tf.test.TestCase):
       x_numpy = np.random.normal(size=shape).astype(np.float32)
       x = tf.cast(tf.constant(x_numpy), dtype=dtype)
 
-      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+      # Compute the expected values at high precision, since the naive
+      # E[x^2] - E[x]^2 formula below suffers catastrophic cancellation:
+      x_numpy = x_numpy.astype(np.float128)
+
+      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
 
       num_elements = np.prod([shape[i] for i in axes])
 
@@ -519,14 +533,21 @@ class MomentsTest(tf.test.TestCase):
 
       axes = [0, 1, 2]
       y_shape = [2]  # Depth of x
-      out_mean, out_var = tf.nn.moments(x, axes)
+
+      inputs_to_compute_gradients_for = [x]
+
+      out_mean, out_var = self._unweighted_moments(
+          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
       if from_y == "mean":
         y = out_mean
       elif from_y == "var":
         y = out_var
-      err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
-      print("Moments %s gradient err = %g" % (from_y, err))
-      self.assertLess(err, 1e-11)
+
+      for (i, v) in enumerate(inputs_to_compute_gradients_for):
+        err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
+                                             y, y_shape)
+        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
+        self.assertLess(err, 1e-11)
 
   def testMeanGlobalGradient(self):
     self._testGlobalGradient(from_y="mean")
@@ -534,19 +555,101 @@ class MomentsTest(tf.test.TestCase):
   def testVarGlobalGradient(self):
     self._testGlobalGradient(from_y="var")
 
-  def testOutputNamesNoKeep(self):
-    """Make sure the output names are stable."""
-    with self.test_session():
-      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
-      self.assertEquals(mean.op.name, "moments/normalize/mean")
-      self.assertEquals(var.op.name, "moments/normalize/variance")
 
-  def testOutputNamesKeep(self):
-    """Make sure the output names are stable."""
-    with self.test_session():
-      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
-      self.assertEquals(mean.op.name, "moments/normalize/mean")
-      self.assertEquals(var.op.name, "moments/normalize/variance")
+class WeightedMomentsTest(MomentsTest):
+  """Tests for nn.weighted_moments.
+
+  Note that this test class inherits from MomentsTest and therefore
+  picks up all of its test methods.
+
+  It modifies MomentsTest in two ways:
+
+  a) By overriding _unweighted_moments, all the codepaths in
+     MomentsTest are executed, but with calls to tf.nn.moments()
+     replaced by calls to tf.nn.weighted_moments() with a constant
+     weight of 1.
+
+  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
+     this test adds multiple additional calls to
+     RunWeightedMomentTest() to exercise correctness with
+     non-constant weights and varying broadcasting situations.
+     (The inherited MomentsTest code paths continue to be
+     exercised as well.)
+
+  """
+
+  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+    weights = tf.constant(1, dtype=x.dtype)
+    if extra_out_grads is not None:
+      # We want to assert gradients wrt the weights as well as x.
+      extra_out_grads.append(weights)
+    return tf.nn.weighted_moments(
+        x, axes, weights, keep_dims=keep_dims)
+
+  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
+    if not dynshapes:
+      super(WeightedMomentsTest, self).RunMomentTest(
+          shape, axes, keep_dims, dtype)
+    else:
+      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
+          shape, axes, keep_dims, dtype)
+
+    # 1:1 weights and inputs
+    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
+
+    # Various broadcasting combinations
+    for idx in range(len(shape)):
+      # Try broadcasting the weights along each axis in turn.
+      weight_shape = [1] * len(shape)
+      weight_shape[idx] = shape[idx]
+
+      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
+
+      # Also try broadcasting with a suffix of length idx + 1
+      weight_shape = shape[-(idx+1):]
+      self.RunWeightedMomentTest(
+          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
+
+  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
+    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
+
+  def RunWeightedMomentTest(
+      self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
+    with self.test_session() as s:
+      x_numpy = np.random.normal(size=shape).astype(np.float32)
+      weights_numpy = np.absolute(  # weights must be positive
+          np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))
+
+      # Promote the numpy inputs to higher precision
+      x_numpy = x_numpy.astype(np.float128)
+      weights_numpy = weights_numpy.astype(np.float128)
+
+      x_shape = [None] * len(shape) if dynshapes else shape
+      weights_shape = (
+          [None] * len(weights_shape) if dynshapes else weights_shape)
+
+      x = tf.placeholder(dtype, shape=x_shape)
+      weights = tf.placeholder(dtype, shape=weights_shape)
+
+      mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)
+
+      ax = tuple(axes)
+
+      def _np_weighted_sum(v):
+        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
+
+      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
+      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
+      expected_mean_squared = np.multiply(expected_mean, expected_mean)
+      expected_x_squared = (
+          _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
+      expected_variance = expected_x_squared - expected_mean_squared
+
+      mean_v, var_v = s.run([mean, var],
+                            feed_dict={x: x_numpy, weights: weights_numpy})
+
+      self.assertAllCloseAccordingToType(expected_mean, mean_v)
+      self.assertAllCloseAccordingToType(expected_variance, var_v)
 
 
 if __name__ == "__main__":