diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index f0ed31d1d1d..e739daa2fcd 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -26,15 +26,18 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 
 __all__ = [
     "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
-    "scattered_embedding_lookup_sparse", "embedding_lookup_unique"
+    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
+    "embedding_lookup_sparse_with_distributed_aggregation"
 ]
 
 
@@ -548,3 +551,326 @@ def _sampled_scattered_embedding_lookup_sparse(params,
     return math_ops.unsorted_segment_sum(embeddings, segment_ids,
                                          num_segments=num_segments,
                                          name=name_scope)
+
+
+def embedding_lookup_sparse_with_distributed_aggregation(
+    params, sp_ids, sp_weights, partition_strategy="mod", name=None,
+    combiner=None, max_norm=None):
+  """Computes embeddings for the given ids and weights.
+
+  Embeddings belonging to the same param are aggregated on that param's device
+  first. This op is intended to decrease data transmission and improve
+  parallelism. See `tf.nn.embedding_lookup_sparse` for the functionality and an
+  example of this op.
+
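+  For example, a minimal sketch (the shard shapes, names and id values below
+  are purely illustrative):
+
+    shards = [tf.get_variable("emb_%d" % i, shape=[5, 8]) for i in range(2)]
+    sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
+                             values=tf.constant([0, 3, 7], tf.int64),
+                             dense_shape=[2, 2])
+    combined = embedding_lookup_sparse_with_distributed_aggregation(
+        shards, sp_ids, sp_weights=None, combiner="mean")
+
+  Here `combined` has shape `[2, 8]`: row 0 is the mean of the embeddings of
+  ids 0 and 3, row 1 is the embedding of id 7, and each shard's partial sum is
+  computed on the device that holds that shard before the final combine.
+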
+  Args:
+    params: A single tensor representing the complete embedding tensor,
+      or a list of P tensors all of same shape except for the first dimension,
+      representing sharded embedding tensors.  Alternatively, a
+      `PartitionedVariable`, created by partitioning along dimension 0. Each
+      element must be appropriately sized for the given `partition_strategy`.
+    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+      where N is typically batch size and M is arbitrary.
+    sp_weights: either a SparseTensor of float / double weights, or None to
+      indicate all weights should be taken to be 1. If specified, sp_weights
+      must have exactly the same shape and indices as sp_ids.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
+    name: Optional name for the op.
+    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
+      and "sum" are supported.
+      "sum" computes the weighted sum of the embedding results for each row.
+      "mean" is the weighted sum divided by the total weight.
+      "sqrtn" is the weighted sum divided by the square root of the sum of the
+      squares of the weights.
+    max_norm: If not None, each embedding is normalized to have l2 norm equal
+      to max_norm before combining.
+
+  Returns:
+    A dense tensor representing the combined embeddings for the
+    sparse ids. For each row in the dense tensor represented by sp_ids, the op
+    looks up the embeddings for all ids in that row, multiplies them by the
+    corresponding weight, and combines these embeddings as specified.
+
+  Raises:
+    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
+      None nor SparseTensor.
+    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
+  """
+  if combiner is None:
+    logging.warn("The default value of combiner will change from \"mean\" "
+                 "to \"sqrtn\" after 2016/11/01.")
+    combiner = "mean"
+  if combiner not in ("mean", "sqrtn", "sum"):
+    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
+  if isinstance(params, variables.PartitionedVariable):
+    params = list(params)  # Iterate to get the underlying Variables.
+  if not isinstance(params, list):
+    params = [params]
+  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
+    raise TypeError("sp_ids must be SparseTensor")
+  ignore_weights = sp_weights is None
+  if not ignore_weights:
+    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
+      raise TypeError("sp_weights must be either None or SparseTensor")
+    sp_ids.values.get_shape().assert_is_compatible_with(
+        sp_weights.values.get_shape())
+    sp_ids.indices.get_shape().assert_is_compatible_with(
+        sp_weights.indices.get_shape())
+    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
+        sp_weights.dense_shape.get_shape())
+    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
+    # sp_weights have equal indices and shapes.
+
+  with ops.name_scope(name, "embedding_lookup_sparse",
+                      params + [sp_ids]) as name:
+    segment_ids = sp_ids.indices[:, 0]
+    if segment_ids.dtype != dtypes.int32:
+      segment_ids = math_ops.cast(segment_ids, dtypes.int32)
+
+    ids = sp_ids.values
+    if ignore_weights:
+      ids, idx = array_ops.unique(ids)
+    else:
+      idx = None
+
+    weights = None if ignore_weights else sp_weights.values
+    embeddings = _embedding_lookup_with_distributed_aggregation(
+        params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
+        weights=weights, idx=idx, segment_ids=segment_ids)
+    # Set the weights to all ones if weights are ignored.
+    if ignore_weights:
+      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
+    if weights.dtype != embeddings.dtype:
+      weights = math_ops.cast(weights, embeddings.dtype)
+    # Reshape weights.
+    ones = array_ops.fill(
+        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
+    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
+                                           0)
+    orig_weights_shape = weights.get_shape()
+    weights = array_ops.reshape(weights, bcast_weights_shape)
+    if embeddings.get_shape().ndims is not None:
+      weights.set_shape(orig_weights_shape.concatenate(
+          [1 for _ in range(embeddings.get_shape().ndims - 1)]))
+
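+    # For example (illustrative weights): a row with weights [2., 3.] has its
+    # weighted sum divided by 5.0 under "mean" and by sqrt(2**2 + 3**2) under
+    # "sqrtn"; "sum" applies no division.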
+    if combiner == "mean":
+      weight_sum = math_ops.segment_sum(weights, segment_ids)
+      embeddings = math_ops.div(embeddings, weight_sum)
+    elif combiner == "sqrtn":
+      weights_squared = math_ops.pow(weights, 2)
+      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
+      weight_sum_sqrt = math_ops.sqrt(weight_sum)
+      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
+    elif combiner != "sum":
+      assert False, "Unrecognized combiner"
+    return embeddings
+
+
+def _do_gather(params, ids, validate_indices=True, name=None):
+  """Deals with doing gather differently for resource variables."""
+  if isinstance(params, resource_variable_ops.ResourceVariable):
+    return params.sparse_read(ids, name=name)
+  return array_ops.gather(
+      params, ids, name=name, validate_indices=validate_indices)
+
+
+def _embedding_lookup_with_distributed_aggregation(
+    params, ids, partition_strategy="mod", name=None, validate_indices=True,
+    max_norm=None, weights=None, idx=None, segment_ids=None):
+  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
+  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
+    raise ValueError("Need at least one param")
+  if isinstance(params, variables.PartitionedVariable):
+    params = list(params)  # Iterate to get the underlying Variables.
+  if not isinstance(params, list):
+    params = [params]
+
+  def maybe_normalize(x):
+    if max_norm is not None:
+      if x.get_shape().ndims is not None:
+        ndims = x.get_shape().ndims
+      else:
+        ndims = array_ops.size(array_ops.shape(x))
+      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
+    return x
+
+  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
+                      params + [ids]) as name:
+    np = len(params)  # Number of partitions
+    # Preserve the resource variable status to avoid accidental dense reads.
+    if not any(isinstance(p, resource_variable_ops.ResourceVariable)
+               for p in params):
+      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
+    if np == 1:
+      with ops.colocate_with(params[0]):
+        ret = maybe_normalize(
+            _do_gather(
+                params[0], ids, validate_indices=validate_indices))
+        ignore_weights = weights is None
+        if not ignore_weights:
+          if weights.dtype != ret.dtype:
+            weights = math_ops.cast(weights, ret.dtype)
+          # Reshape to allow broadcast
+          ones = array_ops.fill(
+              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
+          bcast_weights_shape = array_ops.concat(
+              [array_ops.shape(weights), ones], 0)
+          orig_weights_shape = weights.get_shape()
+          weights = array_ops.reshape(weights, bcast_weights_shape)
+          # Set weights shape after reshape
+          if ret.get_shape().ndims is not None:
+            weights.set_shape(orig_weights_shape.concatenate(
+                [1 for _ in range(ret.get_shape().ndims - 1)]))
+          ret *= weights
+          return math_ops.segment_sum(ret, segment_ids, name=name)
+        else:
+          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
+    else:
+      ids = ops.convert_to_tensor(ids, name="ids")
+      flat_ids = array_ops.reshape(ids, [-1])
+      original_indices = math_ops.range(array_ops.size(flat_ids))
+
+      # Create p_assignments and set new_ids depending on the strategy.
+      if partition_strategy == "mod":
+        p_assignments = flat_ids % np
+        new_ids = flat_ids // np
+      elif partition_strategy == "div":
+        # Compute num_total_ids as the sum of dim-0 of params, then assign to
+        # partitions based on a constant number of ids per partition. Optimize
+        # if we already know the full shape statically.
+        dim_0_size = params[0].get_shape()[0]
+        for p in xrange(1, np):
+          dim_0_size += params[p].get_shape()[0]
+        if dim_0_size.value:
+          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
+        else:
+          dim_0_sizes = []
+          for p in xrange(np):
+            if params[p].get_shape()[0].value is not None:
+              dim_0_sizes.append(params[p].get_shape()[0].value)
+            else:
+              with ops.colocate_with(params[p]):
+                dim_0_sizes.append(array_ops.shape(params[p])[0])
+          num_total_ids = math_ops.reduce_sum(
+              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
+        ids_per_partition = num_total_ids // np
+        extras = num_total_ids % np
+
+        p_assignments = math_ops.maximum(
+            flat_ids // (ids_per_partition + 1),
+            (flat_ids - extras) // ids_per_partition)
+
+        # Emulate a conditional using a boolean indicator tensor
+        is_in_first_extras_partitions = math_ops.cast(
+            p_assignments < extras, flat_ids.dtype)
+        new_ids = (
+            is_in_first_extras_partitions * (
+                flat_ids % (ids_per_partition + 1)) +
+            (1 - is_in_first_extras_partitions) * (
+                (flat_ids - extras) % ids_per_partition))
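+        # Worked example (illustrative): num_total_ids = 10 and np = 3 give
+        # ids_per_partition = 3 and extras = 1, so partition 0 holds ids 0..3
+        # while partitions 1 and 2 hold ids 4..6 and 7..9. For flat_id = 5,
+        # p_assignments = max(5 // 4, (5 - 1) // 3) = 1 and
+        # new_ids = (5 - 1) % 3 = 1, i.e. row 1 of params[1].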
+      else:
+        raise ValueError("Unrecognized partition strategy: " +
+                         partition_strategy)
+
+      # Cast partition assignments to int32 for use in dynamic_partition.
+      # There really should not be more than 2^32 partitions.
+      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
+      # Partition list of ids based on assignments into np separate lists
+      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
+      # Similarly, partition the original indices.
+      pindices = data_flow_ops.dynamic_partition(original_indices,
+                                                 p_assignments, np)
+      # Do np separate lookups, finding embeddings for plist[p] in params[p]
+      partitioned_result = []
+      for p in xrange(np):
+        with ops.colocate_with(params[p]):
+          partitioned_result.append(
+              _do_gather(params[p], gather_ids[p],
+                         validate_indices=validate_indices))
+
+      ignore_weights = weights is None
+      if not ignore_weights:
+        # Partition weights according to pindices.
+        partitioned_weight = []
+        for p in xrange(np):
+          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
+      # Reshape each partition result.
+      element_shape = params[0].get_shape()[1:]
+      for p in params[1:]:
+        element_shape = element_shape.merge_with(p.get_shape()[1:])
+      if element_shape.is_fully_defined():
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
+                array_ops.concat(
+                    [array_ops.shape(pindices[p]), element_shape], 0))
+      else:
+        with ops.colocate_with(params[0]):
+          params_shape = array_ops.shape(params[0])
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
+                array_ops.concat([array_ops.shape(pindices[p]),
+                    array_ops.slice(params_shape, [1], [-1])], 0))
+      # Normalize each partition result.
+      for p in xrange(np):
+        with ops.colocate_with(params[p]):
+          partitioned_result[p] = maybe_normalize(partitioned_result[p])
+      if not ignore_weights:
+        # Multiply each partition result with partition weights.
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
+              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
+                  partitioned_result[p].dtype)
+            # Reshape partition weights.
+            ones = array_ops.fill(
+                array_ops.expand_dims(
+                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
+            bcast_weights_shape = array_ops.concat(
+                [array_ops.shape(partitioned_weight[p]), ones], 0)
+            orig_weights_shape = partitioned_weight[p].get_shape()
+            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
+                                                      bcast_weights_shape)
+            if partitioned_result[p].get_shape().ndims is not None:
+              partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
+                  [1 for _ in range(
+                      partitioned_result[p].get_shape().ndims - 1)]))
+            partitioned_result[p] *= partitioned_weight[p]
+      partitioned_segment_ids = []
+      for p in xrange(np):
+        if not ignore_weights:
+          # Partition segment_ids according to pindices.
+          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
+          # Number the p_segment_ids to meet segment_sum's requirements. Note
+          # that unique_p_segment_ids contains the unique segment ids of this
+          # partition and that their order is preserved.
+          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
+              p_segment_ids)
+          partitioned_segment_ids.append(unique_p_segment_ids)
+          # segment_sum this partition's result.
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = math_ops.segment_sum(
+                partitioned_result[p], unique_p_segment_idx)
+        else:
+          # When weights are ignored, we need the positions of this
+          # partition's elements within idx and segment_ids.
+          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
+          all_idx = math_ops.range(array_ops.shape(idx)[0])
+          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
+          # Gather segment_ids and idx according to these indices.
+          p_segment_ids = array_ops.gather(segment_ids, include_idx)
+          p_idx = array_ops.gather(idx, include_idx)
+          # Number the p_segment_ids, same as in the weighted case above.
+          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
+              p_segment_ids)
+          _, unique_p_idx_idx = array_ops.unique(p_idx)
+          partitioned_segment_ids.append(unique_p_segment_ids)
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = math_ops.sparse_segment_sum(
+                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
+      # Concat each partition's segment_ids and result for final segment_sum.
+      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
+      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
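+      # Example (illustrative): if partition 0 produced partial sums for
+      # segments [0, 2] and partition 1 for segments [1, 2], then
+      # concat_segment_ids is [0, 2, 1, 2] and unsorted_segment_sum adds the
+      # two partial results for segment 2 into one output row.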
+      return math_ops.unsorted_segment_sum(
+          concat_partitioned_result, concat_segment_ids,
+          math_ops.reduce_max(concat_segment_ids) + 1, name=name)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index dfa8067f27a..eb38d70c52c 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.platform import test
+from tensorflow.python.util import compat
 
 
 class SafeEmbeddingLookupSparseTest(test.TestCase):
@@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
       self.assertAllClose(result.eval(), result_abc.eval())
 
 
+def _PName(param_id):
+  return "p" + str(param_id)
+
+
+def _EmbeddingParams(num_shards,
+                     vocab_size,
+                     dtype=dtypes.float32,
+                     shape=None,
+                     use_shapeless_placeholder=False):
+  p = []
+  params = {}
+  feed_dict = {}
+  if not shape:
+    shape = [10]
+  for i in range(num_shards):
+    shard_shape = [vocab_size // num_shards] + shape
+    if i < vocab_size % num_shards:  # Excess goes evenly on the first shards
+      shard_shape[0] += 1
+
+    param_name = _PName(i)
+
+    if use_shapeless_placeholder:
+      param = array_ops.placeholder(dtype, shape=None, name=param_name)
+    else:
+      param = constant_op.constant(
+          1.0, shape=shard_shape, dtype=dtype, name=param_name)
+    p.append(param)
+    np_type = "f" if dtype == dtypes.float32 else "d"
+    val = (np.random.rand(*shard_shape).astype(np_type)) + 1
+    params[param_name + ":0"] = val
+    feed_dict[param.name] = val
+  return p, params, feed_dict
+
+
+def _EmbeddingResult(params,
+                     id_vals,
+                     num_shards,
+                     vocab_size,
+                     partition_strategy="mod",
+                     weight_vals=None):
+  if weight_vals is None:
+    weight_vals = np.copy(id_vals)
+    weight_vals.fill(1)
+  values = []
+  weights = []
+  weights_squared = []
+  for ids, wts in zip(id_vals, weight_vals):
+    value_aggregation = None
+    weight_aggregation = None
+    squared_weight_aggregation = None
+    if isinstance(ids, compat.integral_types):
+      ids = [ids]
+      wts = [wts]
+    for i, weight_value in zip(ids, wts):
+      if partition_strategy == "mod":
+        val = np.copy(params[_PName(i % num_shards) + ":0"][
+            i // num_shards, :]) * weight_value
+      elif partition_strategy == "div":
+        ids_per_partition, extras = divmod(vocab_size, num_shards)
+        threshold = extras * (ids_per_partition + 1)
+        if i < threshold:
+          partition = i // (ids_per_partition + 1)
+          offset = i % (ids_per_partition + 1)
+        else:
+          partition = extras + (i - threshold) // ids_per_partition
+          offset = (i - threshold) % ids_per_partition
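+        # Illustrative "div" mapping: vocab_size = 13 and num_shards = 5 give
+        # ids_per_partition = 2, extras = 3 and threshold = 9, so id 10 maps
+        # to partition 3 + (10 - 9) // 2 = 3 at offset (10 - 9) % 2 = 1.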
+        val = np.copy(params[_PName(partition) + ":0"][
+            offset, :]) * weight_value
+      else:
+        assert False
+      if value_aggregation is None:
+        assert weight_aggregation is None
+        assert squared_weight_aggregation is None
+        value_aggregation = val
+        weight_aggregation = weight_value
+        squared_weight_aggregation = weight_value * weight_value
+      else:
+        assert weight_aggregation is not None
+        assert squared_weight_aggregation is not None
+        value_aggregation += val
+        weight_aggregation += weight_value
+        squared_weight_aggregation += weight_value * weight_value
+    values.append(value_aggregation)
+    weights.append(weight_aggregation)
+    weights_squared.append(squared_weight_aggregation)
+  values = np.array(values).astype(np.float32)
+  weights = np.array(weights).astype(np.float32)
+  weights_squared = np.array(weights_squared).astype(np.float32)
+  return values, weights, weights_squared
+
+
+class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
+
+  def _RandomIdsAndWeights(self, batch_size, vocab_size):
+    max_val_per_entry = 6
+    vals_per_batch_entry = np.random.randint(
+        1, max_val_per_entry, size=batch_size)
+    num_vals = np.sum(vals_per_batch_entry)
+
+    ids = np.random.randint(vocab_size, size=num_vals)
+    weights = 1 + np.random.rand(num_vals)
+
+    indices = []
+    for batch_entry, num_val in enumerate(vals_per_batch_entry):
+      for val_index in range(num_val):
+        indices.append([batch_entry, val_index])
+
+    shape = [batch_size, max_val_per_entry]
+
+    sp_ids = sparse_tensor_lib.SparseTensor(
+        constant_op.constant(indices, dtypes.int64),
+        constant_op.constant(ids, dtypes.int32),
+        constant_op.constant(shape, dtypes.int64))
+    sp_weights = sparse_tensor_lib.SparseTensor(
+        constant_op.constant(indices, dtypes.int64),
+        constant_op.constant(weights, dtypes.float32),
+        constant_op.constant(shape, dtypes.int64))
+
+    return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
+
+  def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
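+    """Splits `vals` into one list per batch entry.
+
+    For example (illustrative values), vals = [a, b, c, d, e] with
+    vals_per_batch_entry = [2, 3] yields [[a, b], [c, d, e]].
+    """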
+    grouped_vals = []
+    index = 0
+    for num_val in vals_per_batch_entry:
+      grouped_vals.append(list(vals[index:(index + num_val)]))
+      index += num_val
+    return grouped_vals
+
+  def testEmbeddingLookupSparse(self):
+    vocab_size = 13
+    batch_size = 10
+    param_shape = [2, 5]
+    expected_lookup_result_shape = [None] + param_shape
+
+    sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
+        self._RandomIdsAndWeights(batch_size, vocab_size))
+
+    grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
+    grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
+    grouped_ignored_weights = self._GroupByBatchEntry(
+        np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
+
+    for num_shards, combiner, dtype, ignore_weights in itertools.product(
+        [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
+        [True, False]):
+
+      with self.test_session():
+        p, params, feed_dict = _EmbeddingParams(
+            num_shards, vocab_size, shape=param_shape, dtype=dtype)
+        embedding_sum = (
+            embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+                p,
+                sp_ids,
+                None if ignore_weights else sp_weights,
+                combiner=combiner))
+
+        self.assertEqual(embedding_sum.get_shape().as_list(),
+                         expected_lookup_result_shape)
+
+        tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
+
+        np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
+            params,
+            grouped_ids,
+            num_shards,
+            vocab_size,
+            weight_vals=grouped_ignored_weights if ignore_weights else
+            grouped_weights)
+        if combiner == "mean":
+          np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
+        if combiner == "sqrtn":
+          np_embedding_sum /= np.reshape(
+              np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
+        self.assertAllClose(np_embedding_sum, tf_embedding_sum)
+
+  def testGradientsEmbeddingLookupSparse(self):
+    vocab_size = 12
+    batch_size = 4
+    param_shape = [2, 3]
+    sp_ids, sp_weights, _, _, _ = (
+        self._RandomIdsAndWeights(batch_size, vocab_size))
+
+    for num_shards, combiner, dtype, ignore_weights in itertools.product(
+        [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
+        [True, False]):
+      with self.test_session():
+        x, params, _ = _EmbeddingParams(
+            num_shards, vocab_size, shape=param_shape, dtype=dtype)
+
+        y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+            x,
+            sp_ids,
+            None if ignore_weights else sp_weights,
+            combiner=combiner)
+        x_name = [_PName(i) for i in range(num_shards)]
+        x_init_value = [params[x_n + ":0"] for x_n in x_name]
+        x_shape = [i.shape for i in x_init_value]
+        y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
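+        # compute_gradient_error returns the maximum elementwise difference
+        # between the numerically and symbolically computed Jacobians of y
+        # with respect to x.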
+        err = gradient_checker.compute_gradient_error(
+            x, x_shape, y, y_shape, x_init_value=x_init_value)
+      self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
+
+  def testIncompatibleShapes(self):
+    with self.test_session():
+      x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
+      sp_ids = sparse_tensor_lib.SparseTensor(
+          constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
+          constant_op.constant([0, 1, 2], dtypes.int32),
+          constant_op.constant([2, 2], dtypes.int64))
+      sp_weights = sparse_tensor_lib.SparseTensor(
+          constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
+          constant_op.constant([12.0, 5.0], dtypes.float32),
+          constant_op.constant([1, 2], dtypes.int64))
+
+      with self.assertRaises(ValueError):
+        embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+            x, sp_ids, sp_weights, combiner="mean")
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 61f2b2ad149..8cd37882570 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -600,9 +600,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
     grouped_ignored_weights = self._GroupByBatchEntry(
         np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
 
-    for num_shards, combiner, dtype, ignore_weights, use_aggregation in \
-        itertools.product([1, 5], ["sum", "mean", "sqrtn"],
-            [dtypes.float32, dtypes.float64], [True, False], [True, False]):
+    for num_shards, combiner, dtype, ignore_weights in itertools.product(
+        [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
+        [True, False]):
 
       with self.test_session():
         p, params, feed_dict = _EmbeddingParams(
@@ -611,8 +611,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
             p,
             sp_ids,
             None if ignore_weights else sp_weights,
-            combiner=combiner,
-            use_aggregation=use_aggregation)
+            combiner=combiner)
 
         self.assertEqual(embedding_sum.get_shape().as_list(),
                          expected_lookup_result_shape)
@@ -640,9 +639,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
     sp_ids, sp_weights, _, _, _ = (
         self._RandomIdsAndWeights(batch_size, vocab_size))
 
-    for num_shards, combiner, dtype, ignore_weights, use_aggregation in \
-        itertools.product([1, 3], ["sum", "mean", "sqrtn"],
-            [dtypes.float32, dtypes.float64], [True, False], [True, False]):
+    for num_shards, combiner, dtype, ignore_weights in itertools.product(
+        [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
+        [True, False]):
       with self.test_session():
         x, params, _ = _EmbeddingParams(
             num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -651,8 +650,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
             x,
             sp_ids,
             None if ignore_weights else sp_weights,
-            combiner=combiner,
-            use_aggregation=use_aggregation)
+            combiner=combiner)
         x_name = [_PName(i) for i in range(num_shards)]
         x_init_value = [params[x_n + ":0"] for x_n in x_name]
         x_shape = [i.shape for i in x_init_value]
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 8116cb648f8..2aeb9ce14d3 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -42,12 +42,8 @@ def _do_gather(params, ids, validate_indices=True, name=None):
 
 
 def embedding_lookup(params, ids, partition_strategy="mod", name=None,
-                     validate_indices=True, max_norm=None,
-                     use_aggregation=False,
-                     weights=None, idx=None, segment_ids=None):
-  """Looks up `ids` in a list of embedding tensors. Note that `use_aggregation`,
-  `weights`, `idx` and `segment_ids` are for internal use, user should not use
-  them.
+                     validate_indices=True, max_norm=None):
+  """Looks up `ids` in a list of embedding tensors.
 
   This function is used to perform parallel lookups on the list of
   tensors in `params`.  It is a generalization of
@@ -98,25 +94,6 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
   """
   if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
     raise ValueError("Need at least one param")
-  if use_aggregation:
-    if segment_ids is None:
-      raise ValueError("segment_ids must not be None \
-          when use_aggregation is True")
-    if weights is not None:
-      if idx is not None:
-        raise ValueError("idx must be None \
-            when weights is not None and use_aggregation is True")
-      weights.get_shape().assert_is_compatible_with(segment_ids.get_shape())
-    else:
-      if idx is None:
-        raise ValueError("idx must not be None \
-            when weights is None and use_aggregation is True")
-      idx.get_shape().assert_is_compatible_with(segment_ids.get_shape())
-  else:
-    if weights is not None or idx is not None or segment_ids is not None:
-      raise ValueError("weights, idx and segment_ids must be None \
-          when use_aggregation is False")
-
   if isinstance(params, variables.PartitionedVariable):
     params = list(params)  # Iterate to get the underlying Variables.
   if not isinstance(params, list):
@@ -137,31 +114,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
       params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
     if np == 1:
       with ops.colocate_with(params[0]):
-        ret = maybe_normalize(
+        return maybe_normalize(
             _do_gather(
                 params[0], ids, validate_indices=validate_indices, name=name))
-        if not use_aggregation:
-          return ret
-        else:
-          ignore_weights = weights is None
-          if not ignore_weights:
-            if weights.dtype != ret.dtype:
-              weights = math_ops.cast(weights, ret.dtype)
-            # Reshape to allow broadcast
-            ones = array_ops.fill(
-                array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
-            bcast_weights_shape = array_ops.concat(
-                [array_ops.shape(weights), ones], 0)
-            orig_weights_shape = weights.get_shape()
-            weights = array_ops.reshape(weights, bcast_weights_shape)
-            # Set weights shape after reshape
-            if ret.get_shape().ndims is not None:
-              weights.set_shape(orig_weights_shape.concatenate(
-                  [1 for _ in range(ret.get_shape().ndims - 1)]))
-            ret *= weights
-            return math_ops.segment_sum(ret, segment_ids)
-          else:
-            return math_ops.sparse_segment_sum(ret, idx, segment_ids)
     else:
       ids = ops.convert_to_tensor(ids, name="ids")
       flat_ids = array_ops.reshape(ids, [-1])
@@ -224,131 +179,39 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
           partitioned_result.append(
               _do_gather(params[p], gather_ids[p],
                          validate_indices=validate_indices))
-
-      if not use_aggregation:
-        # Stitch these back together
-        ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
-                                           name=name)
-        # Reshape to reverse the flattening of ids.
-        element_shape = params[0].get_shape()[1:]
-        for p in params[1:]:
-          element_shape = element_shape.merge_with(p.get_shape()[1:])
-        if element_shape.is_fully_defined():
-          ret = array_ops.reshape(ret,
-                                  array_ops.concat(
-                                      [array_ops.shape(ids), element_shape], 0))
-        else:
-          # It's important that we compute params[0].shape on the right device
-          # to avoid data motion.
-          with ops.colocate_with(params[0]):
-            params_shape = array_ops.shape(params[0])
-          ret = array_ops.reshape(ret,
-                                  array_ops.concat([
-                                      array_ops.shape(ids),
-                                      array_ops.slice(params_shape, [1], [-1])
-                                  ], 0))
-        # output shape = ids.shape + params[*].shape[1:]
-        # Normally the reshape is sufficient, but setting shape explicitly
-        # teaches shape inference that params[1:].get_shape() matters.
-        ret.set_shape(ids.get_shape().concatenate(element_shape))
-        return maybe_normalize(ret)
+      # Stitch these back together
+      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
+                                         name=name)
+      # Reshape to reverse the flattening of ids.
+      element_shape = params[0].get_shape()[1:]
+      for p in params[1:]:
+        element_shape = element_shape.merge_with(p.get_shape()[1:])
+      if element_shape.is_fully_defined():
+        ret = array_ops.reshape(ret,
+                                array_ops.concat(
+                                    [array_ops.shape(ids), element_shape], 0))
       else:
-        # We use distributed aggregation.
-        ignore_weights = weights is None
-        if not ignore_weights:
-          # Partition weights according to pindices.
-          partitioned_weight = []
-          for p in xrange(np):
-            partitioned_weight.append(array_ops.gather(weights, pindices[p]))
-        # Reshape each partition result.
-        element_shape = params[0].get_shape()[1:]
-        for p in params[1:]:
-          element_shape = element_shape.merge_with(p.get_shape()[1:])
-        if element_shape.is_fully_defined():
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = array_ops.reshape(partitioned_result[p],
-                  array_ops.concat(
-                      [array_ops.shape(pindices[p]), element_shape], 0))
-        else:
-          with ops.colocate_with(params[0]):
-            params_shape = array_ops.shape(params[0])
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = array_ops.reshape(partitioned_result[p],
-                  array_ops.concat([array_ops.shape(pindices[p]),
-                      array_ops.slice(params_shape, [1], [-1])], 0))
-        # Normalize each partition result.
-        for p in xrange(np):
-          with ops.colocate_with(params[p]):
-            partitioned_result[p] = maybe_normalize(partitioned_result[p])
-        if not ignore_weights:
-          # Multiply each partition result with partition weights.
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              if partitioned_weight[p].dtype != partitioned_result[p].dtype:
-                partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
-                    partitioned_result[p].dtype)
-              # Reshape partition weights.
-              ones = array_ops.fill(
-                  array_ops.expand_dims(
-                      array_ops.rank(partitioned_result[p]) - 1, 0), 1)
-              bcast_weights_shape = array_ops.concat(
-                  [array_ops.shape(partitioned_weight[p]), ones], 0)
-              orig_weights_shape = partitioned_weight[p].get_shape()
-              partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
-                                                        bcast_weights_shape)
-              if partitioned_result[p].get_shape().ndims is not None:
-                partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
-                    [1 for _ in range(
-                        partitioned_result[p].get_shape().ndims - 1)]))
-              partitioned_result[p] *= partitioned_weight[p]
-        partitioned_segment_ids = []
-        for p in xrange(np):
-          if not ignore_weights:
-            # Partition segment_ids according to pindices.
-            p_segment_ids = array_ops.gather(segment_ids, pindices[p])
-            # Number the p_segment_ids to meet segment_sum's requirements. Note
-            # that unique_p_segment_ids contains unique segment ids of this
-            # partiton and these ids' order is unchanged.
-            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
-                p_segment_ids)
-            partitioned_segment_ids.append(unique_p_segment_ids)
-            # segment_sum this partition's result.
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = math_ops.segment_sum(
-                  partitioned_result[p], unique_p_segment_idx)
-          else:
-            # When ignore weights, we need to get indexs of elements in idx and
-            # segment_ids.
-            _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
-            all_idx = math_ops.range(array_ops.shape(idx)[0])
-            _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
-            # Gather segment_ids and idx according to indexs.
-            p_segment_ids = array_ops.gather(segment_ids, include_idx)
-            p_idx = array_ops.gather(idx, include_idx)
-            # Number the p_segment_ids, same as ignore_weights case above.
-            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
-                p_segment_ids)
-            _, unique_p_idx_idx = array_ops.unique(p_idx)
-            partitioned_segment_ids.append(unique_p_segment_ids)
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = math_ops.sparse_segment_sum(
-                  partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
-        # Concat each partition's segment_ids and result for final segment_sum.
-        concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
-        concat_partitioned_result = array_ops.concat(partitioned_result, 0)
-        return math_ops.unsorted_segment_sum(
-            concat_partitioned_result, concat_segment_ids,
-            math_ops.reduce_max(concat_segment_ids) + 1)
+        # It's important that we compute params[0].shape on the right device
+        # to avoid data motion.
+        with ops.colocate_with(params[0]):
+          params_shape = array_ops.shape(params[0])
+        ret = array_ops.reshape(ret,
+                                array_ops.concat([
+                                    array_ops.shape(ids),
+                                    array_ops.slice(params_shape, [1], [-1])
+                                ], 0))
+      # output shape = ids.shape + params[*].shape[1:]
+      # Normally the reshape is sufficient, but setting shape explicitly
+      # teaches shape inference that params[1:].get_shape() matters.
+      ret.set_shape(ids.get_shape().concatenate(element_shape))
+      return maybe_normalize(ret)
 
 
 def embedding_lookup_sparse(params, sp_ids, sp_weights,
                             partition_strategy="mod",
                             name=None,
                             combiner=None,
-                            max_norm=None,
-                            use_aggregation=False):
+                            max_norm=None):
   """Computes embeddings for the given ids and weights.
 
   This op assumes that there is at least one id for each row in the dense tensor
@@ -381,9 +244,6 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       squares of the weights.
     max_norm: If not None, each embedding is normalized to have l2 norm equal
       to max_norm before combining.
-    use_aggregation: If True, embeddings belonging to same param are aggregated
-      on that device first. This option is intented to reduce data transmission
-      and increase concurrency.
 
   Returns:
     A dense tensor representing the combined embeddings for the
@@ -458,89 +318,56 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
     else:
       idx = None
 
-    if not use_aggregation:
-      embeddings = embedding_lookup(
-          params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
-      if not ignore_weights:
-        weights = sp_weights.values
-        if weights.dtype != embeddings.dtype:
-          weights = math_ops.cast(weights, embeddings.dtype)
-
-        # Reshape weights to allow broadcast
-        ones = array_ops.fill(
-            array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
-        bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
-                                               0)
-
-        orig_weights_shape = weights.get_shape()
-        weights = array_ops.reshape(weights, bcast_weights_shape)
-
-        # Set the weight shape, since after reshaping to bcast_weights_shape,
-        # the shape becomes None.
-        if embeddings.get_shape().ndims is not None:
-          weights.set_shape(orig_weights_shape.concatenate(
-              [1 for _ in range(embeddings.get_shape().ndims - 1)]))
-
-        embeddings *= weights
-
-        if combiner == "sum":
-          embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
-        elif combiner == "mean":
-          embeddings = math_ops.segment_sum(embeddings, segment_ids)
-          weight_sum = math_ops.segment_sum(weights, segment_ids)
-          embeddings = math_ops.div(embeddings, weight_sum, name=name)
-        elif combiner == "sqrtn":
-          embeddings = math_ops.segment_sum(embeddings, segment_ids)
-          weights_squared = math_ops.pow(weights, 2)
-          weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
-          weight_sum_sqrt = math_ops.sqrt(weight_sum)
-          embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
-        else:
-          assert False, "Unrecognized combiner"
-      else:
-        assert idx is not None
-        if combiner == "sum":
-          embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
-                                                   name=name)
-        elif combiner == "mean":
-          embeddings = math_ops.sparse_segment_mean(embeddings, idx,
-                                                    segment_ids, name=name)
-        elif combiner == "sqrtn":
-          embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
-                                                      segment_ids, name=name)
-        else:
-          assert False, "Unrecognized combiner"
-    else:
-      weights = None if ignore_weights else sp_weights.values
-      embeddings = embedding_lookup(
-          params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
-          use_aggregation=True,
-          weights=weights, idx=idx, segment_ids=segment_ids)
-      # Set weights to all one if ignore weights.
-      if ignore_weights:
-        weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
+    embeddings = embedding_lookup(
+        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
+    if not ignore_weights:
+      weights = sp_weights.values
       if weights.dtype != embeddings.dtype:
         weights = math_ops.cast(weights, embeddings.dtype)
-      # Reshape weights.
+
+      # Reshape weights to allow broadcast
       ones = array_ops.fill(
           array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
       bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                              0)
+
       orig_weights_shape = weights.get_shape()
       weights = array_ops.reshape(weights, bcast_weights_shape)
+
+      # Set the weight shape, since after reshaping to bcast_weights_shape,
+      # the shape becomes None.
       if embeddings.get_shape().ndims is not None:
         weights.set_shape(orig_weights_shape.concatenate(
             [1 for _ in range(embeddings.get_shape().ndims - 1)]))
 
-      if combiner == "mean":
+      embeddings *= weights
+
+      if combiner == "sum":
+        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
+      elif combiner == "mean":
+        embeddings = math_ops.segment_sum(embeddings, segment_ids)
         weight_sum = math_ops.segment_sum(weights, segment_ids)
-        embeddings = math_ops.div(embeddings, weight_sum)
+        embeddings = math_ops.div(embeddings, weight_sum, name=name)
       elif combiner == "sqrtn":
+        embeddings = math_ops.segment_sum(embeddings, segment_ids)
         weights_squared = math_ops.pow(weights, 2)
         weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
         weight_sum_sqrt = math_ops.sqrt(weight_sum)
-        embeddings = math_ops.div(embeddings, weight_sum_sqrt)
-      elif combiner != "sum":
+        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
+      else:
+        assert False, "Unrecognized combiner"
+    else:
+      assert idx is not None
+      if combiner == "sum":
+        embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
+                                                 name=name)
+      elif combiner == "mean":
+        embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
+                                                  name=name)
+      elif combiner == "sqrtn":
+        embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
+                                                    segment_ids, name=name)
+      else:
         assert False, "Unrecognized combiner"
 
     return embeddings