diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index f0ed31d1d1d..e42e885364c 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from six.moves import xrange # pylint: disable=redefined-builtin + from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op @@ -26,15 +28,18 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging __all__ = [ "safe_embedding_lookup_sparse", "scattered_embedding_lookup", - "scattered_embedding_lookup_sparse", "embedding_lookup_unique" + "scattered_embedding_lookup_sparse", "embedding_lookup_unique", + "embedding_lookup_sparse_with_distributed_aggregation" ] @@ -548,3 +553,326 @@ def _sampled_scattered_embedding_lookup_sparse(params, return math_ops.unsorted_segment_sum(embeddings, segment_ids, num_segments=num_segments, name=name_scope) + + +def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids, + sp_weights, partition_strategy="mod", name=None, combiner=None, + max_norm=None): + """Computes embeddings for the given ids and weights. + + Embeddings belonging to same param are aggregated on that device first. This + op is intended to decrease data transmission and improve parallelism. See + `tf.nn.embedding_lookup_sparse` for the functionality and example of this op. + + Args: + params: A single tensor representing the complete embedding tensor, + or a list of P tensors all of same shape except for the first dimension, + representing sharded embedding tensors. Alternatively, a + `PartitionedVariable`, created by partitioning along dimension 0. Each + element must be appropriately sized for the given `partition_strategy`. + sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId), + where N is typically batch size and M is arbitrary. + sp_weights: either a SparseTensor of float / double weights, or None to + indicate all weights should be taken to be 1. If specified, sp_weights + must have exactly the same shape and indices as sp_ids. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. See `tf.nn.embedding_lookup` for more details. + name: Optional name for the op. + combiner: A string specifying the reduction op. Currently "mean", "sqrtn" + and "sum" are supported. + "sum" computes the weighted sum of the embedding results for each row. + "mean" is the weighted sum divided by the total weight. + "sqrtn" is the weighted sum divided by the square root of the sum of the + squares of the weights. + max_norm: If not None, each embedding is normalized to have l2 norm equal + to max_norm before combining. 
+ + Returns: + A dense tensor representing the combined embeddings for the + sparse ids. For each row in the dense tensor represented by sp_ids, the op + looks up the embeddings for all ids in that row, multiplies them by the + corresponding weight, and combines these embeddings as specified. + + Raises: + TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither + None nor SparseTensor. + ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}. + """ + if combiner is None: + logging.warn("The default value of combiner will change from \"mean\" " + "to \"sqrtn\" after 2016/11/01.") + combiner = "mean" + if combiner not in ("mean", "sqrtn", "sum"): + raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'") + if isinstance(params, variables.PartitionedVariable): + params = list(params) # Iterate to get the underlying Variables. + if not isinstance(params, list): + params = [params] + if not isinstance(sp_ids, sparse_tensor.SparseTensor): + raise TypeError("sp_ids must be SparseTensor") + ignore_weights = sp_weights is None + if not ignore_weights: + if not isinstance(sp_weights, sparse_tensor.SparseTensor): + raise TypeError("sp_weights must be either None or SparseTensor") + sp_ids.values.get_shape().assert_is_compatible_with( + sp_weights.values.get_shape()) + sp_ids.indices.get_shape().assert_is_compatible_with( + sp_weights.indices.get_shape()) + sp_ids.dense_shape.get_shape().assert_is_compatible_with( + sp_weights.dense_shape.get_shape()) + # TODO(yleon): Add enhanced node assertions to verify that sp_ids and + # sp_weights have equal indices and shapes. + + with ops.name_scope(name, "embedding_lookup_sparse", + params + [sp_ids]) as name: + segment_ids = sp_ids.indices[:, 0] + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) + + ids = sp_ids.values + if ignore_weights: + ids, idx = array_ops.unique(ids) + else: + idx = None + + weights = None if ignore_weights else sp_weights.values + embeddings = _embedding_lookup_with_distributed_aggregation( + params, ids, partition_strategy=partition_strategy, max_norm=max_norm, + weights=weights, idx=idx, segment_ids=segment_ids) + # Set weights to all one if ignore weights. + if ignore_weights: + weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1) + if weights.dtype != embeddings.dtype: + weights = math_ops.cast(weights, embeddings.dtype) + # Reshape weights. 
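+ # The weights are a 1-D tensor with one entry per id, while the aggregated
+ # embeddings have rank >= 2, so append singleton dimensions to the weights.
+ # After segment_sum, the per-row weight totals then broadcast against the
+ # embedding dimensions in the mean / sqrtn divisions below.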
+ ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1) + bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], + 0) + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + if embeddings.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(embeddings.get_shape().ndims - 1)])) + + if combiner == "mean": + weight_sum = math_ops.segment_sum(weights, segment_ids) + embeddings = math_ops.div(embeddings, weight_sum) + elif combiner == "sqrtn": + weights_squared = math_ops.pow(weights, 2) + weight_sum = math_ops.segment_sum(weights_squared, segment_ids) + weight_sum_sqrt = math_ops.sqrt(weight_sum) + embeddings = math_ops.div(embeddings, weight_sum_sqrt) + elif combiner != "sum": + assert False, "Unrecognized combiner" + return embeddings + + +def _do_gather(params, ids, validate_indices=True, name=None): + """Deals with doing gather differently for resource variables.""" + if isinstance(params, resource_variable_ops.ResourceVariable): + return params.sparse_read(ids, name=name) + return array_ops.gather( + params, ids, name=name, validate_indices=validate_indices) + + +def _embedding_lookup_with_distributed_aggregation(params, ids, + partition_strategy="mod", name=None, validate_indices=True, max_norm=None, + weights=None, idx=None, segment_ids=None): + """ Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.""" + if params is None or params == []: # pylint: disable=g-explicit-bool-comparison + raise ValueError("Need at least one param") + if isinstance(params, variables.PartitionedVariable): + params = list(params) # Iterate to get the underlying Variables. + if not isinstance(params, list): + params = [params] + def maybe_normalize(x): + if max_norm is not None: + if x.get_shape().ndims is not None: + ndims = x.get_shape().ndims + else: + ndims = array_ops.size(array_ops.shape(x)) + return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims))) + return x + with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation", + params + [ids]) as name: + np = len(params) # Number of partitions + # Preserve the resource variable status to avoid accidental dense reads. 
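+ # Converting a ResourceVariable via convert_n_to_tensor_or_indexed_slices
+ # would read the whole embedding matrix densely; keeping it as a
+ # ResourceVariable lets _do_gather use sparse_read instead.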
+ if not any(isinstance(p, resource_variable_ops.ResourceVariable) + for p in params): + params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params") + if np == 1: + with ops.colocate_with(params[0]): + ret = maybe_normalize( + _do_gather( + params[0], ids, validate_indices=validate_indices)) + ignore_weights = weights is None + if not ignore_weights: + if weights.dtype != ret.dtype: + weights = math_ops.cast(weights, ret.dtype) + # Reshape to allow broadcast + ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1) + bcast_weights_shape = array_ops.concat( + [array_ops.shape(weights), ones], 0) + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + # Set weights shape after reshape + if ret.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(ret.get_shape().ndims - 1)])) + ret *= weights + return math_ops.segment_sum(ret, segment_ids, name=name) + else: + return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name) + else: + ids = ops.convert_to_tensor(ids, name="ids") + flat_ids = array_ops.reshape(ids, [-1]) + original_indices = math_ops.range(array_ops.size(flat_ids)) + + # Create p_assignments and set new_ids depending on the strategy. + if partition_strategy == "mod": + p_assignments = flat_ids % np + new_ids = flat_ids // np + elif partition_strategy == "div": + # Compute num_total_ids as the sum of dim-0 of params, then assign to + # partitions based on a constant number of ids per partition. Optimize + # if we already know the full shape statically. + dim_0_size = params[0].get_shape()[0] + for p in xrange(1, np): + dim_0_size += params[p].get_shape()[0] + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype) + else: + dim_0_sizes = [] + for p in xrange(np): + if params[p].get_shape()[0].value is not None: + dim_0_sizes.append(params[p].get_shape()[0].value) + else: + with ops.colocate_with(params[p]): + dim_0_sizes.append(array_ops.shape(params[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // np + extras = num_total_ids % np + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + is_in_first_extras_partitions = math_ops.cast( + p_assignments < extras, flat_ids.dtype) + new_ids = ( + is_in_first_extras_partitions * ( + flat_ids % (ids_per_partition + 1)) + + (1 - is_in_first_extras_partitions) * ( + (flat_ids - extras) % ids_per_partition)) + else: + raise ValueError("Unrecognized partition strategy: " + + partition_strategy) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into np separate lists + gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np) + # Similarly, partition the original indices. 
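+ # pindices[p] holds, for each id routed to partition p, its position in the
+ # flattened id list; it is used below to pull the matching weights and
+ # segment ids for that partition.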
+ pindices = data_flow_ops.dynamic_partition(original_indices, + p_assignments, np) + # Do np separate lookups, finding embeddings for plist[p] in params[p] + partitioned_result = [] + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result.append( + _do_gather(params[p], gather_ids[p], + validate_indices=validate_indices)) + + ignore_weights = weights is None + if not ignore_weights: + # Partition weights according to pindices. + partitioned_weight = [] + for p in xrange(np): + partitioned_weight.append(array_ops.gather(weights, pindices[p])) + # Reshape each partition result. + element_shape = params[0].get_shape()[1:] + for p in params[1:]: + element_shape = element_shape.merge_with(p.get_shape()[1:]) + if element_shape.is_fully_defined(): + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = array_ops.reshape(partitioned_result[p], + array_ops.concat( + [array_ops.shape(pindices[p]), element_shape], 0)) + else: + with ops.colocate_with(params[0]): + params_shape = array_ops.shape(params[0]) + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = array_ops.reshape(partitioned_result[p], + array_ops.concat([array_ops.shape(pindices[p]), + array_ops.slice(params_shape, [1], [-1])], 0)) + # Normalize each partition result. + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = maybe_normalize(partitioned_result[p]) + if not ignore_weights: + # Multiply each partition result with partition weights. + for p in xrange(np): + with ops.colocate_with(params[p]): + if partitioned_weight[p].dtype != partitioned_result[p].dtype: + partitioned_weight[p] = math_ops.cast(partitioned_weight[p], + partitioned_result[p].dtype) + # Reshape partition weights. + ones = array_ops.fill( + array_ops.expand_dims( + array_ops.rank(partitioned_result[p]) - 1, 0), 1) + bcast_weights_shape = array_ops.concat( + [array_ops.shape(partitioned_weight[p]), ones], 0) + orig_weights_shape = partitioned_weight[p].get_shape() + partitioned_weight[p] = array_ops.reshape(partitioned_weight[p], + bcast_weights_shape) + if partitioned_result[p].get_shape().ndims is not None: + partitioned_weight[p].set_shape(orig_weights_shape.concatenate( + [1 for _ in range( + partitioned_result[p].get_shape().ndims - 1)])) + partitioned_result[p] *= partitioned_weight[p] + partitioned_segment_ids = [] + for p in xrange(np): + if not ignore_weights: + # Partition segment_ids according to pindices. + p_segment_ids = array_ops.gather(segment_ids, pindices[p]) + # Number the p_segment_ids to meet segment_sum's requirements. Note + # that unique_p_segment_ids contains unique segment ids of this + # partiton and these ids' order is unchanged. + unique_p_segment_ids, unique_p_segment_idx = array_ops.unique( + p_segment_ids) + partitioned_segment_ids.append(unique_p_segment_ids) + # segment_sum this partition's result. + with ops.colocate_with(params[p]): + partitioned_result[p] = math_ops.segment_sum( + partitioned_result[p], unique_p_segment_idx) + else: + # When ignore weights, we need to get indexs of elements in idx and + # segment_ids. + _, exclude_idx = array_ops.setdiff1d(idx, pindices[p]) + all_idx = math_ops.range(array_ops.shape(idx)[0]) + _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx) + # Gather segment_ids and idx according to indexs. + p_segment_ids = array_ops.gather(segment_ids, include_idx) + p_idx = array_ops.gather(idx, include_idx) + # Number the p_segment_ids, same as ignore_weights case above. 
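+ # unique() renumbers this partition's segment ids into a dense 0..k-1 range
+ # (their order is preserved), so the per-partition partial sum below has one
+ # row per distinct segment seen in this partition.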
+ unique_p_segment_ids, unique_p_segment_idx = array_ops.unique( + p_segment_ids) + _, unique_p_idx_idx = array_ops.unique(p_idx) + partitioned_segment_ids.append(unique_p_segment_ids) + with ops.colocate_with(params[p]): + partitioned_result[p] = math_ops.sparse_segment_sum( + partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx) + # Concat each partition's segment_ids and result for final segment_sum. + concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0) + concat_partitioned_result = array_ops.concat(partitioned_result, 0) + return math_ops.unsorted_segment_sum( + concat_partitioned_result, concat_segment_ids, + math_ops.reduce_max(concat_segment_ids) + 1, name=name) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dfa8067f27a..eb38d70c52c 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import init_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test +from tensorflow.python.util import compat class SafeEmbeddingLookupSparseTest(test.TestCase): @@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): self.assertAllClose(result.eval(), result_abc.eval()) +def _PName(param_id): + return "p" + str(param_id) + + +def _EmbeddingParams(num_shards, + vocab_size, + dtype=dtypes.float32, + shape=None, + use_shapeless_placeholder=False): + p = [] + params = {} + feed_dict = {} + if not shape: + shape = [10] + for i in range(num_shards): + shard_shape = [vocab_size // num_shards] + shape + if i < vocab_size % num_shards: # Excess goes evenly on the first shards + shard_shape[0] += 1 + + param_name = _PName(i) + + if use_shapeless_placeholder: + param = array_ops.placeholder(dtype, shape=None, name=param_name) + else: + param = constant_op.constant( + 1.0, shape=shard_shape, dtype=dtype, name=param_name) + p.append(param) + np_type = "f" if dtype == dtypes.float32 else "d" + val = (np.random.rand(*shard_shape).astype(np_type)) + 1 + params[param_name + ":0"] = val + feed_dict[param.name] = val + return p, params, feed_dict + + +def _EmbeddingResult(params, + id_vals, + num_shards, + vocab_size, + partition_strategy="mod", + weight_vals=None): + if weight_vals is None: + weight_vals = np.copy(id_vals) + weight_vals.fill(1) + values = [] + weights = [] + weights_squared = [] + for ids, wts in zip(id_vals, weight_vals): + value_aggregation = None + weight_aggregation = None + squared_weight_aggregation = None + if isinstance(ids, compat.integral_types): + ids = [ids] + wts = [wts] + for i, weight_value in zip(ids, wts): + if partition_strategy == "mod": + val = np.copy(params[_PName(i % num_shards) + ":0"][ + i // num_shards, :]) * weight_value + elif partition_strategy == "div": + ids_per_partition, extras = divmod(vocab_size, num_shards) + threshold = extras * (ids_per_partition + 1) + if i < threshold: + partition = i // (ids_per_partition + 1) + offset = i % (ids_per_partition + 1) + else: + partition = extras + (i - threshold) // ids_per_partition + offset = (i - threshold) % ids_per_partition + 
val = np.copy(params[_PName(partition) + ":0"][ + offset, :]) * weight_value + else: + assert False + if value_aggregation is None: + assert weight_aggregation is None + assert squared_weight_aggregation is None + value_aggregation = val + weight_aggregation = weight_value + squared_weight_aggregation = weight_value * weight_value + else: + assert weight_aggregation is not None + assert squared_weight_aggregation is not None + value_aggregation += val + weight_aggregation += weight_value + squared_weight_aggregation += weight_value * weight_value + values.append(value_aggregation) + weights.append(weight_aggregation) + weights_squared.append(squared_weight_aggregation) + values = np.array(values).astype(np.float32) + weights = np.array(weights).astype(np.float32) + weights_squared = np.array(weights_squared).astype(np.float32) + return values, weights, weights_squared + + +class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): + + def _RandomIdsAndWeights(self, batch_size, vocab_size): + max_val_per_entry = 6 + vals_per_batch_entry = np.random.randint( + 1, max_val_per_entry, size=batch_size) + num_vals = np.sum(vals_per_batch_entry) + + ids = np.random.randint(vocab_size, size=num_vals) + weights = 1 + np.random.rand(num_vals) + + indices = [] + for batch_entry, num_val in enumerate(vals_per_batch_entry): + for val_index in range(num_val): + indices.append([batch_entry, val_index]) + + shape = [batch_size, max_val_per_entry] + + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sp_ids, sp_weights, ids, weights, vals_per_batch_entry + + def _GroupByBatchEntry(self, vals, vals_per_batch_entry): + grouped_vals = [] + index = 0 + for num_val in vals_per_batch_entry: + grouped_vals.append(list(vals[index:(index + num_val)])) + index += num_val + return grouped_vals + + def testEmbeddingLookupSparse(self): + vocab_size = 13 + batch_size = 10 + param_shape = [2, 5] + expected_lookup_result_shape = [None] + param_shape + + sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( + self._RandomIdsAndWeights(batch_size, vocab_size)) + + grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry) + grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry) + grouped_ignored_weights = self._GroupByBatchEntry( + np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) + + for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5], + ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): + + with self.test_session(): + p, params, feed_dict = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + embedding_sum = \ + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + p, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + + self.assertEqual(embedding_sum.get_shape().as_list(), + expected_lookup_result_shape) + + tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) + + np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult( + params, + grouped_ids, + num_shards, + vocab_size, + weight_vals=grouped_ignored_weights if ignore_weights else + grouped_weights) + if combiner == "mean": + np_embedding_sum /= 
np.reshape(np_weight_sum, (batch_size, 1, 1)) + if combiner == "sqrtn": + np_embedding_sum /= np.reshape( + np.sqrt(np_weight_sq_sum), (batch_size, 1, 1)) + self.assertAllClose(np_embedding_sum, tf_embedding_sum) + + def testGradientsEmbeddingLookupSparse(self): + vocab_size = 12 + batch_size = 4 + param_shape = [2, 3] + sp_ids, sp_weights, _, _, _ = ( + self._RandomIdsAndWeights(batch_size, vocab_size)) + + for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3], + ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): + with self.test_session(): + x, params, _ = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + + y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + x_name = [_PName(i) for i in range(num_shards)] + x_init_value = [params[x_n + ":0"] for x_n in x_name] + x_shape = [i.shape for i in x_init_value] + y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:]) + err = gradient_checker.compute_gradient_error( + x, x_shape, y, y_shape, x_init_value=x_init_value) + self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3) + + def testIncompatibleShapes(self): + with self.test_session(): + x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32) + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64), + constant_op.constant([0, 1, 2], dtypes.int32), + constant_op.constant([2, 2], dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1]], dtypes.int64), + constant_op.constant([12.0, 5.0], dtypes.float32), + constant_op.constant([1, 2], dtypes.int64)) + + with self.assertRaises(ValueError): + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, sp_ids, sp_weights, combiner="mean") + + if __name__ == "__main__": test.main()
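
Usage sketch (illustrative only, not part of the diff): a minimal way to call the new op, assuming TF 1.x with the module imported directly from contrib; the shard count, vocabulary size, ids and weights below are made up.

import numpy as np
import tensorflow as tf

from tensorflow.contrib.layers.python.layers import embedding_ops as contrib_embedding_ops

num_shards, vocab_size, dim = 3, 12, 4
# One shard per parameter server; each shard owns vocab_size // num_shards rows.
params = [
    tf.Variable(
        np.random.rand(vocab_size // num_shards, dim).astype(np.float32),
        name="shard_%d" % i) for i in range(num_shards)
]

# Batch of two rows: row 0 holds ids {1, 5}, row 1 holds id {7}.
indices = [[0, 0], [0, 1], [1, 0]]
sp_ids = tf.SparseTensor(indices, tf.constant([1, 5, 7], dtype=tf.int64), [2, 2])
sp_weights = tf.SparseTensor(
    indices, tf.constant([0.5, 1.5, 1.0], dtype=tf.float32), [2, 2])

# Partial weighted sums are computed next to each shard; only those small
# partial results travel to the final unsorted_segment_sum.
embedded = (
    contrib_embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
        params, sp_ids, sp_weights, combiner="mean"))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(embedded))  # A (2, 4) array of per-row combined embeddings.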