From bc1a11b331673b68529cb4f08b783d8069967bd9 Mon Sep 17 00:00:00 2001 From: Ziming Dong Date: Wed, 22 Mar 2017 23:44:53 -0400 Subject: [PATCH 1/6] add distributed aggregation to embedding_lookup_sparse --- tensorflow/python/ops/embedding_ops.py | 248 ++++++++++++++++++------- 1 file changed, 183 insertions(+), 65 deletions(-) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 2aeb9ce14d3..ef0fb83278b 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -42,7 +42,8 @@ def _do_gather(params, ids, validate_indices=True, name=None): def embedding_lookup(params, ids, partition_strategy="mod", name=None, - validate_indices=True, max_norm=None): + validate_indices=True, max_norm=None, + use_aggregation=False, weights=None, segment_ids=None): """Looks up `ids` in a list of embedding tensors. This function is used to perform parallel lookups on the list of @@ -94,6 +95,16 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, """ if params is None or params == []: # pylint: disable=g-explicit-bool-comparison raise ValueError("Need at least one param") + if use_aggregation: + if segment_ids is None: + raise ValueError("segment_ids must not be None \ + when use_aggregation is True") + if weights is not None: + weights.get_shape().assert_is_compatible_with(segment_ids.get_shape()) + else: + if weights is not None or segment_ids is not None: + raise ValueError("weights and segment_ids must be None \ + when use_aggregation is False") if isinstance(params, variables.PartitionedVariable): params = list(params) # Iterate to get the underlying Variables. if not isinstance(params, list): @@ -114,9 +125,27 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params") if np == 1: with ops.colocate_with(params[0]): - return maybe_normalize( + ret = maybe_normalize( _do_gather( params[0], ids, validate_indices=validate_indices, name=name)) + if not use_aggregation: + return ret + else: + ignore_weights = weights is None + if not ignore_weights: + if weights.dtype != ret.dtype: + weights = math_ops.cast(weights, ret.dtype) + ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1) + bcast_weights_shape = array_ops.concat( + [array_ops.shape(weights), ones], 0) + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + if ret.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(ret.get_shape().ndims - 1)])) + ret *= weights + return math_ops.segment_sum(ret, segment_ids) else: ids = ops.convert_to_tensor(ids, name="ids") flat_ids = array_ops.reshape(ids, [-1]) @@ -179,39 +208,100 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, partitioned_result.append( _do_gather(params[p], gather_ids[p], validate_indices=validate_indices)) - # Stitch these back together - ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result, - name=name) - # Reshape to reverse the flattening of ids. 
- element_shape = params[0].get_shape()[1:] - for p in params[1:]: - element_shape = element_shape.merge_with(p.get_shape()[1:]) - if element_shape.is_fully_defined(): - ret = array_ops.reshape(ret, - array_ops.concat( - [array_ops.shape(ids), element_shape], 0)) + if not use_aggregation: + # Stitch these back together + ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result, + name=name) + # Reshape to reverse the flattening of ids. + element_shape = params[0].get_shape()[1:] + for p in params[1:]: + element_shape = element_shape.merge_with(p.get_shape()[1:]) + if element_shape.is_fully_defined(): + ret = array_ops.reshape(ret, + array_ops.concat( + [array_ops.shape(ids), element_shape], 0)) + else: + # It's important that we compute params[0].shape on the right device + # to avoid data motion. + with ops.colocate_with(params[0]): + params_shape = array_ops.shape(params[0]) + ret = array_ops.reshape(ret, + array_ops.concat([ + array_ops.shape(ids), + array_ops.slice(params_shape, [1], [-1]) + ], 0)) + # output shape = ids.shape + params[*].shape[1:] + # Normally the reshape is sufficient, but setting shape explicitly + # teaches shape inference that params[1:].get_shape() matters. + ret.set_shape(ids.get_shape().concatenate(element_shape)) + return maybe_normalize(ret) else: - # It's important that we compute params[0].shape on the right device - # to avoid data motion. - with ops.colocate_with(params[0]): - params_shape = array_ops.shape(params[0]) - ret = array_ops.reshape(ret, - array_ops.concat([ - array_ops.shape(ids), - array_ops.slice(params_shape, [1], [-1]) - ], 0)) - # output shape = ids.shape + params[*].shape[1:] - # Normally the reshape is sufficient, but setting shape explicitly - # teaches shape inference that params[1:].get_shape() matters. 
- ret.set_shape(ids.get_shape().concatenate(element_shape)) - return maybe_normalize(ret) + ignore_weights = weights is None + if not ignore_weights: + partitioned_weight = [] + for p in xrange(np): + partitioned_weight.append(array_ops.gather(weights, pindices[p])) + element_shape = params[0].get_shape()[1:] + for p in params[1:]: + element_shape = element_shape.merge_with(p.get_shape()[1:]) + if element_shape.is_fully_defined(): + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = array_ops.reshape(partitioned_result[p], + array_ops.concat( + [array_ops.shape(pindices[p]), element_shape], 0)) + else: + with ops.colocate_with(params[0]): + params_shape = array_ops.shape(params[0]) + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = array_ops.reshape(partitioned_result[p], + array_ops.concat([array_ops.shape(pindices[p]), + array_ops.slice(params_shape, [1], [-1])], 0)) + for p in xrange(np): + with ops.colocate_with(params[p]): + partitioned_result[p] = maybe_normalize(partitioned_result[p]) + if not ignore_weights: + for p in xrange(np): + with ops.colocate_with(params[p]): + if partitioned_weight[p].dtype != partitioned_result[p].dtype: + partitioned_weight[p] = math_ops.cast(partitioned_weight[p], + partitioned_result[p].dtype) + ones = array_ops.fill( + array_ops.expand_dims( + array_ops.rank(partitioned_result[p]) - 1, 0), 1) + bcast_weights_shape = array_ops.concat( + [array_ops.shape(partitioned_weight[p]), ones], 0) + orig_weights_shape = partitioned_weight[p].get_shape() + partitioned_weight[p] = array_ops.reshape(partitioned_weight[p], + bcast_weights_shape) + if partitioned_result[p].get_shape().ndims is not None: + partitioned_weight[p].set_shape(orig_weights_shape.concatenate( + [1 for _ in range( + partitioned_result[p].get_shape().ndims - 1)])) + partitioned_result[p] *= partitioned_weight[p] + partitioned_segment_id = [] + for p in xrange(np): + segment_id = array_ops.gather(segment_ids, pindices[p]) + unique_segment_id, idx = array_ops.unique(segment_id) + partitioned_segment_id.append(unique_segment_id) + with ops.colocate_with(params[p]): + partitioned_result[p] = math_ops.segment_sum(partitioned_result[p], + idx) + concat_segment_id = array_ops.concat(partitioned_segment_id, 0) + concat_partitioned_result = array_ops.concat(partitioned_result, 0) + concat_partitioned_result.set_shape( + segment_ids.get_shape().concatenate(element_shape)) + return math_ops.unsorted_segment_sum(concat_partitioned_result, + concat_segment_id, math_ops.reduce_max(concat_segment_id) + 1) def embedding_lookup_sparse(params, sp_ids, sp_weights, partition_strategy="mod", name=None, combiner=None, - max_norm=None): + max_norm=None, + use_aggregation=False): """Computes embeddings for the given ids and weights. 
This op assumes that there is at least one id for each row in the dense tensor @@ -313,61 +403,89 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, segment_ids = math_ops.cast(segment_ids, dtypes.int32) ids = sp_ids.values - if ignore_weights: - ids, idx = array_ops.unique(ids) - else: - idx = None + if not use_aggregation: + if ignore_weights: + ids, idx = array_ops.unique(ids) + else: + idx = None - embeddings = embedding_lookup( - params, ids, partition_strategy=partition_strategy, max_norm=max_norm) - if not ignore_weights: - weights = sp_weights.values + embeddings = embedding_lookup( + params, ids, partition_strategy=partition_strategy, max_norm=max_norm) + if not ignore_weights: + weights = sp_weights.values + if weights.dtype != embeddings.dtype: + weights = math_ops.cast(weights, embeddings.dtype) + + # Reshape weights to allow broadcast + ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1) + bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], + 0) + + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + + # Set the weight shape, since after reshaping to bcast_weights_shape, + # the shape becomes None. + if embeddings.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(embeddings.get_shape().ndims - 1)])) + + embeddings *= weights + + if combiner == "sum": + embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name) + elif combiner == "mean": + embeddings = math_ops.segment_sum(embeddings, segment_ids) + weight_sum = math_ops.segment_sum(weights, segment_ids) + embeddings = math_ops.div(embeddings, weight_sum, name=name) + elif combiner == "sqrtn": + embeddings = math_ops.segment_sum(embeddings, segment_ids) + weights_squared = math_ops.pow(weights, 2) + weight_sum = math_ops.segment_sum(weights_squared, segment_ids) + weight_sum_sqrt = math_ops.sqrt(weight_sum) + embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name) + else: + assert False, "Unrecognized combiner" + else: + assert idx is not None + if combiner == "sum": + embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids, + name=name) + elif combiner == "mean": + embeddings = math_ops.sparse_segment_mean(embeddings, idx, + segment_ids, name=name) + elif combiner == "sqrtn": + embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, + segment_ids, name=name) + else: + assert False, "Unrecognized combiner" + else: + weights = None if ignore_weights else sp_weights.values + embeddings = embedding_lookup( + params, ids, partition_strategy=partition_strategy, max_norm=max_norm, + use_aggregation=True, weights=weights, segment_ids=segment_ids) + if ignore_weights: + weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1) if weights.dtype != embeddings.dtype: weights = math_ops.cast(weights, embeddings.dtype) - - # Reshape weights to allow broadcast ones = array_ops.fill( array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1) bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0) - orig_weights_shape = weights.get_shape() weights = array_ops.reshape(weights, bcast_weights_shape) - - # Set the weight shape, since after reshaping to bcast_weights_shape, - # the shape becomes None. 
if embeddings.get_shape().ndims is not None: weights.set_shape(orig_weights_shape.concatenate( [1 for _ in range(embeddings.get_shape().ndims - 1)])) - - embeddings *= weights - - if combiner == "sum": - embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name) - elif combiner == "mean": - embeddings = math_ops.segment_sum(embeddings, segment_ids) + if combiner == "mean": weight_sum = math_ops.segment_sum(weights, segment_ids) embeddings = math_ops.div(embeddings, weight_sum, name=name) elif combiner == "sqrtn": - embeddings = math_ops.segment_sum(embeddings, segment_ids) weights_squared = math_ops.pow(weights, 2) weight_sum = math_ops.segment_sum(weights_squared, segment_ids) weight_sum_sqrt = math_ops.sqrt(weight_sum) embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name) - else: + elif combiner != "sum": assert False, "Unrecognized combiner" - else: - assert idx is not None - if combiner == "sum": - embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids, - name=name) - elif combiner == "mean": - embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids, - name=name) - elif combiner == "sqrtn": - embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, - segment_ids, name=name) - else: - assert False, "Unrecognized combiner" - return embeddings From a9bbe754cf5f31d14b907069ea256d2d9c977d5a Mon Sep 17 00:00:00 2001 From: Ziming Dong Date: Thu, 23 Mar 2017 02:51:56 -0400 Subject: [PATCH 2/6] reduce memory cost of aggregation when weights is None --- tensorflow/python/ops/embedding_ops.py | 99 ++++++++++++++++++++------ 1 file changed, 76 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index ef0fb83278b..4d5ad3e2405 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -43,7 +43,8 @@ def _do_gather(params, ids, validate_indices=True, name=None): def embedding_lookup(params, ids, partition_strategy="mod", name=None, validate_indices=True, max_norm=None, - use_aggregation=False, weights=None, segment_ids=None): + use_aggregation=False, + weights=None, idx=None, segment_ids=None): """Looks up `ids` in a list of embedding tensors. This function is used to perform parallel lookups on the list of @@ -100,11 +101,20 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, raise ValueError("segment_ids must not be None \ when use_aggregation is True") if weights is not None: + if idx is not None: + raise ValueError("idx must be None \ + when weights is not None and use_aggregation is True") weights.get_shape().assert_is_compatible_with(segment_ids.get_shape()) + else: + if idx is None: + raise ValueError("idx must not be None \ + when weights is None and use_aggregation is True") + idx.get_shape().assert_is_compatible_with(segment_ids.get_shape()) else: - if weights is not None or segment_ids is not None: - raise ValueError("weights and segment_ids must be None \ + if weights is not None or idx is not None or segment_ids is not None: + raise ValueError("weights, idx and segment_ids must be None \ when use_aggregation is False") + if isinstance(params, variables.PartitionedVariable): params = list(params) # Iterate to get the underlying Variables. 
if not isinstance(params, list): @@ -135,17 +145,21 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, if not ignore_weights: if weights.dtype != ret.dtype: weights = math_ops.cast(weights, ret.dtype) + # Reshape to allow broadcast ones = array_ops.fill( array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1) bcast_weights_shape = array_ops.concat( [array_ops.shape(weights), ones], 0) orig_weights_shape = weights.get_shape() weights = array_ops.reshape(weights, bcast_weights_shape) + # Set weights shape after reshape if ret.get_shape().ndims is not None: weights.set_shape(orig_weights_shape.concatenate( [1 for _ in range(ret.get_shape().ndims - 1)])) ret *= weights - return math_ops.segment_sum(ret, segment_ids) + return math_ops.segment_sum(ret, segment_ids) + else: + return math_ops.sparse_segment_sum(ret, idx, segment_ids) else: ids = ops.convert_to_tensor(ids, name="ids") flat_ids = array_ops.reshape(ids, [-1]) @@ -208,6 +222,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, partitioned_result.append( _do_gather(params[p], gather_ids[p], validate_indices=validate_indices)) + if not use_aggregation: # Stitch these back together ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result, @@ -236,11 +251,14 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, ret.set_shape(ids.get_shape().concatenate(element_shape)) return maybe_normalize(ret) else: + # We use distributed aggregation. ignore_weights = weights is None if not ignore_weights: + # Partition weights according to pindices. partitioned_weight = [] for p in xrange(np): partitioned_weight.append(array_ops.gather(weights, pindices[p])) + # Reshape each partition result. element_shape = params[0].get_shape()[1:] for p in params[1:]: element_shape = element_shape.merge_with(p.get_shape()[1:]) @@ -258,15 +276,18 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, partitioned_result[p] = array_ops.reshape(partitioned_result[p], array_ops.concat([array_ops.shape(pindices[p]), array_ops.slice(params_shape, [1], [-1])], 0)) + # Normalize each partition result. for p in xrange(np): with ops.colocate_with(params[p]): partitioned_result[p] = maybe_normalize(partitioned_result[p]) if not ignore_weights: + # Multiply each partition result with partition weights. for p in xrange(np): with ops.colocate_with(params[p]): if partitioned_weight[p].dtype != partitioned_result[p].dtype: partitioned_weight[p] = math_ops.cast(partitioned_weight[p], partitioned_result[p].dtype) + # Reshape partition weights. ones = array_ops.fill( array_ops.expand_dims( array_ops.rank(partitioned_result[p]) - 1, 0), 1) @@ -280,20 +301,44 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, [1 for _ in range( partitioned_result[p].get_shape().ndims - 1)])) partitioned_result[p] *= partitioned_weight[p] - partitioned_segment_id = [] + partitioned_segment_ids = [] for p in xrange(np): - segment_id = array_ops.gather(segment_ids, pindices[p]) - unique_segment_id, idx = array_ops.unique(segment_id) - partitioned_segment_id.append(unique_segment_id) - with ops.colocate_with(params[p]): - partitioned_result[p] = math_ops.segment_sum(partitioned_result[p], - idx) - concat_segment_id = array_ops.concat(partitioned_segment_id, 0) + if not ignore_weights: + # Partition segment_ids according to pindices. + p_segment_ids = array_ops.gather(segment_ids, pindices[p]) + # Number the p_segment_ids to meet segment_sum's requirements. 
Note
+            # that unique_p_segment_ids contains unique segment ids of this
+            # partition and these ids' order is unchanged.
+            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
+                p_segment_ids)
+            partitioned_segment_ids.append(unique_p_segment_ids)
+            # segment_sum this partition's result.
+            with ops.colocate_with(params[p]):
+              partitioned_result[p] = math_ops.segment_sum(
+                  partitioned_result[p], unique_p_segment_idx)
+          else:
+            # When ignoring weights, we need the positions of the elements in
+            # idx and segment_ids that belong to this partition.
+            _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
+            all_idx = math_ops.range(array_ops.shape(idx)[0])
+            _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
+            # Gather segment_ids and idx according to those indices.
+            p_segment_ids = array_ops.gather(segment_ids, include_idx)
+            p_idx = array_ops.gather(idx, include_idx)
+            # Number the p_segment_ids, same as ignore_weights case above.
+            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
+                p_segment_ids)
+            _, unique_p_idx_idx = array_ops.unique(p_idx)
+            partitioned_segment_ids.append(unique_p_segment_ids)
+            with ops.colocate_with(params[p]):
+              partitioned_result[p] = math_ops.sparse_segment_sum(
+                  partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
+        # Concat each partition's segment_ids and result for final segment_sum.
+        concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
         concat_partitioned_result = array_ops.concat(partitioned_result, 0)
-        concat_partitioned_result.set_shape(
-            segment_ids.get_shape().concatenate(element_shape))
-        return math_ops.unsorted_segment_sum(concat_partitioned_result,
-            concat_segment_id, math_ops.reduce_max(concat_segment_id) + 1)
+        return math_ops.unsorted_segment_sum(
+            concat_partitioned_result, concat_segment_ids,
+            math_ops.reduce_max(concat_segment_ids) + 1)


 def embedding_lookup_sparse(params, sp_ids, sp_weights,
@@ -301,7 +346,7 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
                             name=None,
                             combiner=None,
                             max_norm=None,
-                            use_aggregation=False):
+                            use_aggregation=True):
   """Computes embeddings for the given ids and weights.

   This op assumes that there is at least one id for each row in the dense tensor
@@ -334,6 +379,9 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       squares of the weights.
     max_norm: If not None, each embedding is normalized to have l2 norm equal
       to max_norm before combining.
+    use_aggregation: If True, embeddings belonging to the same param are
+      aggregated on that device first. This option is intended to reduce data
+      transmission and increase concurrency.

   Returns:
     A dense tensor representing the combined embeddings for the
@@ -403,12 +451,12 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       segment_ids = math_ops.cast(segment_ids, dtypes.int32)

     ids = sp_ids.values
-    if not use_aggregation:
-      if ignore_weights:
-        ids, idx = array_ops.unique(ids)
-      else:
-        idx = None
+    if ignore_weights:
+      ids, idx = array_ops.unique(ids)
+    else:
+      idx = None

+    if not use_aggregation:
       embeddings = embedding_lookup(
           params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
       if not ignore_weights:
@@ -464,11 +512,14 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       weights = None if ignore_weights else sp_weights.values
       embeddings = embedding_lookup(
           params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
-          use_aggregation=True, weights=weights, segment_ids=segment_ids)
+          use_aggregation=True,
+          weights=weights, idx=idx, segment_ids=segment_ids)
+      # Set weights to all ones if weights are ignored.
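+      # (With all-ones weights, the "mean" combiner divides each row's sum by
+      # the number of ids in that row, and "sqrtn" divides by the square root
+      # of that count.)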
if ignore_weights: weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1) if weights.dtype != embeddings.dtype: weights = math_ops.cast(weights, embeddings.dtype) + # Reshape weights. ones = array_ops.fill( array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1) bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], @@ -478,6 +529,7 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, if embeddings.get_shape().ndims is not None: weights.set_shape(orig_weights_shape.concatenate( [1 for _ in range(embeddings.get_shape().ndims - 1)])) + if combiner == "mean": weight_sum = math_ops.segment_sum(weights, segment_ids) embeddings = math_ops.div(embeddings, weight_sum, name=name) @@ -488,4 +540,5 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name) elif combiner != "sum": assert False, "Unrecognized combiner" + return embeddings From fa751ecd3a78b25386a069a0a7807b705df8e32f Mon Sep 17 00:00:00 2001 From: Ziming Dong Date: Thu, 23 Mar 2017 03:32:43 -0400 Subject: [PATCH 3/6] modify unit test --- tensorflow/python/kernel_tests/embedding_ops_test.py | 12 ++++++------ tensorflow/python/ops/embedding_ops.py | 10 ++++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 8cd37882570..255c7e152d4 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -600,9 +600,9 @@ class EmbeddingLookupSparseTest(test.TestCase): grouped_ignored_weights = self._GroupByBatchEntry( np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) - for num_shards, combiner, dtype, ignore_weights in itertools.product( - [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], - [True, False]): + for num_shards, combiner, dtype, ignore_weights, use_aggregation in \ + itertools.product([1, 5], ["sum", "mean", "sqrtn"], + [dtypes.float32, dtypes.float64], [True, False], [True, False]): with self.test_session(): p, params, feed_dict = _EmbeddingParams( @@ -639,9 +639,9 @@ class EmbeddingLookupSparseTest(test.TestCase): sp_ids, sp_weights, _, _, _ = ( self._RandomIdsAndWeights(batch_size, vocab_size)) - for num_shards, combiner, dtype, ignore_weights in itertools.product( - [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], - [True, False]): + for num_shards, combiner, dtype, ignore_weights, use_aggregation in \ + itertools.product([1, 3], ["sum", "mean", "sqrtn"], + [dtypes.float32, dtypes.float64], [True, False], [True, False]): with self.test_session(): x, params, _ = _EmbeddingParams( num_shards, vocab_size, shape=param_shape, dtype=dtype) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 4d5ad3e2405..8116cb648f8 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -45,7 +45,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, validate_indices=True, max_norm=None, use_aggregation=False, weights=None, idx=None, segment_ids=None): - """Looks up `ids` in a list of embedding tensors. + """Looks up `ids` in a list of embedding tensors. Note that `use_aggregation`, + `weights`, `idx` and `segment_ids` are for internal use, user should not use + them. This function is used to perform parallel lookups on the list of tensors in `params`. 
It is a generalization of @@ -346,7 +348,7 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, name=None, combiner=None, max_norm=None, - use_aggregation=True): + use_aggregation=False): """Computes embeddings for the given ids and weights. This op assumes that there is at least one id for each row in the dense tensor @@ -532,12 +534,12 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, if combiner == "mean": weight_sum = math_ops.segment_sum(weights, segment_ids) - embeddings = math_ops.div(embeddings, weight_sum, name=name) + embeddings = math_ops.div(embeddings, weight_sum) elif combiner == "sqrtn": weights_squared = math_ops.pow(weights, 2) weight_sum = math_ops.segment_sum(weights_squared, segment_ids) weight_sum_sqrt = math_ops.sqrt(weight_sum) - embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name) + embeddings = math_ops.div(embeddings, weight_sum_sqrt) elif combiner != "sum": assert False, "Unrecognized combiner" From f00b5c9b2e1c3c02ff5ebe49e3a95cc28d15946c Mon Sep 17 00:00:00 2001 From: Ziming Dong Date: Thu, 23 Mar 2017 03:43:46 -0400 Subject: [PATCH 4/6] fix unit test bug --- tensorflow/python/kernel_tests/embedding_ops_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 255c7e152d4..61f2b2ad149 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -611,7 +611,8 @@ class EmbeddingLookupSparseTest(test.TestCase): p, sp_ids, None if ignore_weights else sp_weights, - combiner=combiner) + combiner=combiner, + use_aggregation=use_aggregation) self.assertEqual(embedding_sum.get_shape().as_list(), expected_lookup_result_shape) @@ -650,7 +651,8 @@ class EmbeddingLookupSparseTest(test.TestCase): x, sp_ids, None if ignore_weights else sp_weights, - combiner=combiner) + combiner=combiner, + use_aggregation=use_aggregation) x_name = [_PName(i) for i in range(num_shards)] x_init_value = [params[x_n + ":0"] for x_n in x_name] x_shape = [i.shape for i in x_init_value] From 71821efb545b0f467214c3647c2645d0b44deedc Mon Sep 17 00:00:00 2001 From: Ziming Dong Date: Wed, 29 Mar 2017 12:46:44 -0400 Subject: [PATCH 5/6] move code to contrib --- .../layers/python/layers/embedding_ops.py | 328 +++++++++++++++++- .../python/layers/embedding_ops_test.py | 221 ++++++++++++ .../python/kernel_tests/embedding_ops_test.py | 18 +- tensorflow/python/ops/embedding_ops.py | 297 ++++------------ 4 files changed, 618 insertions(+), 246 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index f0ed31d1d1d..e739daa2fcd 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -26,15 +26,18 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging __all__ = [ "safe_embedding_lookup_sparse", "scattered_embedding_lookup", - 
"scattered_embedding_lookup_sparse", "embedding_lookup_unique" + "scattered_embedding_lookup_sparse", "embedding_lookup_unique", + "embedding_lookup_sparse_with_distributed_aggregation" ] @@ -548,3 +551,326 @@ def _sampled_scattered_embedding_lookup_sparse(params, return math_ops.unsorted_segment_sum(embeddings, segment_ids, num_segments=num_segments, name=name_scope) + + +def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids, + sp_weights, partition_strategy="mod", name=None, combiner=None, + max_norm=None): + """Computes embeddings for the given ids and weights. + + Embeddings belonging to same param are aggregated on that device first. This + op is intended to decrease data transmission and improve parallelism. See + `tf.nn.embedding_lookup_sparse` for the functionality and example of this op. + + Args: + params: A single tensor representing the complete embedding tensor, + or a list of P tensors all of same shape except for the first dimension, + representing sharded embedding tensors. Alternatively, a + `PartitionedVariable`, created by partitioning along dimension 0. Each + element must be appropriately sized for the given `partition_strategy`. + sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId), + where N is typically batch size and M is arbitrary. + sp_weights: either a SparseTensor of float / double weights, or None to + indicate all weights should be taken to be 1. If specified, sp_weights + must have exactly the same shape and indices as sp_ids. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. See `tf.nn.embedding_lookup` for more details. + name: Optional name for the op. + combiner: A string specifying the reduction op. Currently "mean", "sqrtn" + and "sum" are supported. + "sum" computes the weighted sum of the embedding results for each row. + "mean" is the weighted sum divided by the total weight. + "sqrtn" is the weighted sum divided by the square root of the sum of the + squares of the weights. + max_norm: If not None, each embedding is normalized to have l2 norm equal + to max_norm before combining. + + Returns: + A dense tensor representing the combined embeddings for the + sparse ids. For each row in the dense tensor represented by sp_ids, the op + looks up the embeddings for all ids in that row, multiplies them by the + corresponding weight, and combines these embeddings as specified. + + Raises: + TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither + None nor SparseTensor. + ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}. + """ + if combiner is None: + logging.warn("The default value of combiner will change from \"mean\" " + "to \"sqrtn\" after 2016/11/01.") + combiner = "mean" + if combiner not in ("mean", "sqrtn", "sum"): + raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'") + if isinstance(params, variables.PartitionedVariable): + params = list(params) # Iterate to get the underlying Variables. 
+ if not isinstance(params, list): + params = [params] + if not isinstance(sp_ids, sparse_tensor.SparseTensor): + raise TypeError("sp_ids must be SparseTensor") + ignore_weights = sp_weights is None + if not ignore_weights: + if not isinstance(sp_weights, sparse_tensor.SparseTensor): + raise TypeError("sp_weights must be either None or SparseTensor") + sp_ids.values.get_shape().assert_is_compatible_with( + sp_weights.values.get_shape()) + sp_ids.indices.get_shape().assert_is_compatible_with( + sp_weights.indices.get_shape()) + sp_ids.dense_shape.get_shape().assert_is_compatible_with( + sp_weights.dense_shape.get_shape()) + # TODO(yleon): Add enhanced node assertions to verify that sp_ids and + # sp_weights have equal indices and shapes. + + with ops.name_scope(name, "embedding_lookup_sparse", + params + [sp_ids]) as name: + segment_ids = sp_ids.indices[:, 0] + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) + + ids = sp_ids.values + if ignore_weights: + ids, idx = array_ops.unique(ids) + else: + idx = None + + weights = None if ignore_weights else sp_weights.values + embeddings = _embedding_lookup_with_distributed_aggregation( + params, ids, partition_strategy=partition_strategy, max_norm=max_norm, + weights=weights, idx=idx, segment_ids=segment_ids) + # Set weights to all one if ignore weights. + if ignore_weights: + weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1) + if weights.dtype != embeddings.dtype: + weights = math_ops.cast(weights, embeddings.dtype) + # Reshape weights. + ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1) + bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], + 0) + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + if embeddings.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(embeddings.get_shape().ndims - 1)])) + + if combiner == "mean": + weight_sum = math_ops.segment_sum(weights, segment_ids) + embeddings = math_ops.div(embeddings, weight_sum) + elif combiner == "sqrtn": + weights_squared = math_ops.pow(weights, 2) + weight_sum = math_ops.segment_sum(weights_squared, segment_ids) + weight_sum_sqrt = math_ops.sqrt(weight_sum) + embeddings = math_ops.div(embeddings, weight_sum_sqrt) + elif combiner != "sum": + assert False, "Unrecognized combiner" + return embeddings + + +def _do_gather(params, ids, validate_indices=True, name=None): + """Deals with doing gather differently for resource variables.""" + if isinstance(params, resource_variable_ops.ResourceVariable): + return params.sparse_read(ids, name=name) + return array_ops.gather( + params, ids, name=name, validate_indices=validate_indices) + + +def _embedding_lookup_with_distributed_aggregation(params, ids, + partition_strategy="mod", name=None, validate_indices=True, max_norm=None, + weights=None, idx=None, segment_ids=None): + """ Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.""" + if params is None or params == []: # pylint: disable=g-explicit-bool-comparison + raise ValueError("Need at least one param") + if isinstance(params, variables.PartitionedVariable): + params = list(params) # Iterate to get the underlying Variables. 
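+  # Accept a single tensor or variable as well as a list of shards.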
+ if not isinstance(params, list): + params = [params] + def maybe_normalize(x): + if max_norm is not None: + if x.get_shape().ndims is not None: + ndims = x.get_shape().ndims + else: + ndims = array_ops.size(array_ops.shape(x)) + return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims))) + return x + with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation", + params + [ids]) as name: + np = len(params) # Number of partitions + # Preserve the resource variable status to avoid accidental dense reads. + if not any(isinstance(p, resource_variable_ops.ResourceVariable) + for p in params): + params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params") + if np == 1: + with ops.colocate_with(params[0]): + ret = maybe_normalize( + _do_gather( + params[0], ids, validate_indices=validate_indices)) + ignore_weights = weights is None + if not ignore_weights: + if weights.dtype != ret.dtype: + weights = math_ops.cast(weights, ret.dtype) + # Reshape to allow broadcast + ones = array_ops.fill( + array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1) + bcast_weights_shape = array_ops.concat( + [array_ops.shape(weights), ones], 0) + orig_weights_shape = weights.get_shape() + weights = array_ops.reshape(weights, bcast_weights_shape) + # Set weights shape after reshape + if ret.get_shape().ndims is not None: + weights.set_shape(orig_weights_shape.concatenate( + [1 for _ in range(ret.get_shape().ndims - 1)])) + ret *= weights + return math_ops.segment_sum(ret, segment_ids, name=name) + else: + return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name) + else: + ids = ops.convert_to_tensor(ids, name="ids") + flat_ids = array_ops.reshape(ids, [-1]) + original_indices = math_ops.range(array_ops.size(flat_ids)) + + # Create p_assignments and set new_ids depending on the strategy. + if partition_strategy == "mod": + p_assignments = flat_ids % np + new_ids = flat_ids // np + elif partition_strategy == "div": + # Compute num_total_ids as the sum of dim-0 of params, then assign to + # partitions based on a constant number of ids per partition. Optimize + # if we already know the full shape statically. + dim_0_size = params[0].get_shape()[0] + for p in xrange(1, np): + dim_0_size += params[p].get_shape()[0] + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype) + else: + dim_0_sizes = [] + for p in xrange(np): + if params[p].get_shape()[0].value is not None: + dim_0_sizes.append(params[p].get_shape()[0].value) + else: + with ops.colocate_with(params[p]): + dim_0_sizes.append(array_ops.shape(params[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // np + extras = num_total_ids % np + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + is_in_first_extras_partitions = math_ops.cast( + p_assignments < extras, flat_ids.dtype) + new_ids = ( + is_in_first_extras_partitions * ( + flat_ids % (ids_per_partition + 1)) + + (1 - is_in_first_extras_partitions) * ( + (flat_ids - extras) % ids_per_partition)) + else: + raise ValueError("Unrecognized partition strategy: " + + partition_strategy) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. 
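+      # A quick check of the "div" arithmetic above (illustrative numbers):
+      # with num_total_ids = 10 and np = 3, ids_per_partition = 3 and
+      # extras = 1, so the partitions own 4, 3 and 3 ids; id 5 gets
+      # p_assignments = max(5 // 4, (5 - 1) // 3) = 1 and
+      # new_ids = (5 - 1) % 3 = 1.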
+      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
+      # Partition list of ids based on assignments into np separate lists
+      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
+      # Similarly, partition the original indices.
+      pindices = data_flow_ops.dynamic_partition(original_indices,
+                                                 p_assignments, np)
+      # Do np separate lookups, finding embeddings for plist[p] in params[p]
+      partitioned_result = []
+      for p in xrange(np):
+        with ops.colocate_with(params[p]):
+          partitioned_result.append(
+              _do_gather(params[p], gather_ids[p],
+                         validate_indices=validate_indices))
+
+      ignore_weights = weights is None
+      if not ignore_weights:
+        # Partition weights according to pindices.
+        partitioned_weight = []
+        for p in xrange(np):
+          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
+      # Reshape each partition result.
+      element_shape = params[0].get_shape()[1:]
+      for p in params[1:]:
+        element_shape = element_shape.merge_with(p.get_shape()[1:])
+      if element_shape.is_fully_defined():
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
+                array_ops.concat(
+                    [array_ops.shape(pindices[p]), element_shape], 0))
+      else:
+        with ops.colocate_with(params[0]):
+          params_shape = array_ops.shape(params[0])
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
+                array_ops.concat([array_ops.shape(pindices[p]),
+                    array_ops.slice(params_shape, [1], [-1])], 0))
+      # Normalize each partition result.
+      for p in xrange(np):
+        with ops.colocate_with(params[p]):
+          partitioned_result[p] = maybe_normalize(partitioned_result[p])
+      if not ignore_weights:
+        # Multiply each partition result with partition weights.
+        for p in xrange(np):
+          with ops.colocate_with(params[p]):
+            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
+              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
+                                                    partitioned_result[p].dtype)
+            # Reshape partition weights.
+            ones = array_ops.fill(
+                array_ops.expand_dims(
+                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
+            bcast_weights_shape = array_ops.concat(
+                [array_ops.shape(partitioned_weight[p]), ones], 0)
+            orig_weights_shape = partitioned_weight[p].get_shape()
+            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
+                                                      bcast_weights_shape)
+            if partitioned_result[p].get_shape().ndims is not None:
+              partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
+                  [1 for _ in range(
+                      partitioned_result[p].get_shape().ndims - 1)]))
+            partitioned_result[p] *= partitioned_weight[p]
+      partitioned_segment_ids = []
+      for p in xrange(np):
+        if not ignore_weights:
+          # Partition segment_ids according to pindices.
+          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
+          # Number the p_segment_ids to meet segment_sum's requirements. Note
+          # that unique_p_segment_ids contains unique segment ids of this
+          # partition and these ids' order is unchanged.
+          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
+              p_segment_ids)
+          partitioned_segment_ids.append(unique_p_segment_ids)
+          # segment_sum this partition's result.
+          with ops.colocate_with(params[p]):
+            partitioned_result[p] = math_ops.segment_sum(
+                partitioned_result[p], unique_p_segment_idx)
+        else:
+          # When ignoring weights, we need the positions of the elements in
+          # idx and segment_ids that belong to this partition.
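+          # (setdiff1d(idx, pindices[p]) returns the positions in idx whose
+          # values this partition does not own; removing those from the full
+          # index range leaves exactly the positions it does own.)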
+ _, exclude_idx = array_ops.setdiff1d(idx, pindices[p]) + all_idx = math_ops.range(array_ops.shape(idx)[0]) + _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx) + # Gather segment_ids and idx according to indexs. + p_segment_ids = array_ops.gather(segment_ids, include_idx) + p_idx = array_ops.gather(idx, include_idx) + # Number the p_segment_ids, same as ignore_weights case above. + unique_p_segment_ids, unique_p_segment_idx = array_ops.unique( + p_segment_ids) + _, unique_p_idx_idx = array_ops.unique(p_idx) + partitioned_segment_ids.append(unique_p_segment_ids) + with ops.colocate_with(params[p]): + partitioned_result[p] = math_ops.sparse_segment_sum( + partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx) + # Concat each partition's segment_ids and result for final segment_sum. + concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0) + concat_partitioned_result = array_ops.concat(partitioned_result, 0) + return math_ops.unsorted_segment_sum( + concat_partitioned_result, concat_segment_ids, + math_ops.reduce_max(concat_segment_ids) + 1, name=name) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dfa8067f27a..eb38d70c52c 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import init_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test +from tensorflow.python.util import compat class SafeEmbeddingLookupSparseTest(test.TestCase): @@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): self.assertAllClose(result.eval(), result_abc.eval()) +def _PName(param_id): + return "p" + str(param_id) + + +def _EmbeddingParams(num_shards, + vocab_size, + dtype=dtypes.float32, + shape=None, + use_shapeless_placeholder=False): + p = [] + params = {} + feed_dict = {} + if not shape: + shape = [10] + for i in range(num_shards): + shard_shape = [vocab_size // num_shards] + shape + if i < vocab_size % num_shards: # Excess goes evenly on the first shards + shard_shape[0] += 1 + + param_name = _PName(i) + + if use_shapeless_placeholder: + param = array_ops.placeholder(dtype, shape=None, name=param_name) + else: + param = constant_op.constant( + 1.0, shape=shard_shape, dtype=dtype, name=param_name) + p.append(param) + np_type = "f" if dtype == dtypes.float32 else "d" + val = (np.random.rand(*shard_shape).astype(np_type)) + 1 + params[param_name + ":0"] = val + feed_dict[param.name] = val + return p, params, feed_dict + + +def _EmbeddingResult(params, + id_vals, + num_shards, + vocab_size, + partition_strategy="mod", + weight_vals=None): + if weight_vals is None: + weight_vals = np.copy(id_vals) + weight_vals.fill(1) + values = [] + weights = [] + weights_squared = [] + for ids, wts in zip(id_vals, weight_vals): + value_aggregation = None + weight_aggregation = None + squared_weight_aggregation = None + if isinstance(ids, compat.integral_types): + ids = [ids] + wts = [wts] + for i, weight_value in zip(ids, wts): + if partition_strategy == "mod": + val = np.copy(params[_PName(i % num_shards) + ":0"][ + i // 
num_shards, :]) * weight_value + elif partition_strategy == "div": + ids_per_partition, extras = divmod(vocab_size, num_shards) + threshold = extras * (ids_per_partition + 1) + if i < threshold: + partition = i // (ids_per_partition + 1) + offset = i % (ids_per_partition + 1) + else: + partition = extras + (i - threshold) // ids_per_partition + offset = (i - threshold) % ids_per_partition + val = np.copy(params[_PName(partition) + ":0"][ + offset, :]) * weight_value + else: + assert False + if value_aggregation is None: + assert weight_aggregation is None + assert squared_weight_aggregation is None + value_aggregation = val + weight_aggregation = weight_value + squared_weight_aggregation = weight_value * weight_value + else: + assert weight_aggregation is not None + assert squared_weight_aggregation is not None + value_aggregation += val + weight_aggregation += weight_value + squared_weight_aggregation += weight_value * weight_value + values.append(value_aggregation) + weights.append(weight_aggregation) + weights_squared.append(squared_weight_aggregation) + values = np.array(values).astype(np.float32) + weights = np.array(weights).astype(np.float32) + weights_squared = np.array(weights_squared).astype(np.float32) + return values, weights, weights_squared + + +class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): + + def _RandomIdsAndWeights(self, batch_size, vocab_size): + max_val_per_entry = 6 + vals_per_batch_entry = np.random.randint( + 1, max_val_per_entry, size=batch_size) + num_vals = np.sum(vals_per_batch_entry) + + ids = np.random.randint(vocab_size, size=num_vals) + weights = 1 + np.random.rand(num_vals) + + indices = [] + for batch_entry, num_val in enumerate(vals_per_batch_entry): + for val_index in range(num_val): + indices.append([batch_entry, val_index]) + + shape = [batch_size, max_val_per_entry] + + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sp_ids, sp_weights, ids, weights, vals_per_batch_entry + + def _GroupByBatchEntry(self, vals, vals_per_batch_entry): + grouped_vals = [] + index = 0 + for num_val in vals_per_batch_entry: + grouped_vals.append(list(vals[index:(index + num_val)])) + index += num_val + return grouped_vals + + def testEmbeddingLookupSparse(self): + vocab_size = 13 + batch_size = 10 + param_shape = [2, 5] + expected_lookup_result_shape = [None] + param_shape + + sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( + self._RandomIdsAndWeights(batch_size, vocab_size)) + + grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry) + grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry) + grouped_ignored_weights = self._GroupByBatchEntry( + np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) + + for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5], + ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): + + with self.test_session(): + p, params, feed_dict = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + embedding_sum = \ + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + p, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + + 
self.assertEqual(embedding_sum.get_shape().as_list(), + expected_lookup_result_shape) + + tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) + + np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult( + params, + grouped_ids, + num_shards, + vocab_size, + weight_vals=grouped_ignored_weights if ignore_weights else + grouped_weights) + if combiner == "mean": + np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1)) + if combiner == "sqrtn": + np_embedding_sum /= np.reshape( + np.sqrt(np_weight_sq_sum), (batch_size, 1, 1)) + self.assertAllClose(np_embedding_sum, tf_embedding_sum) + + def testGradientsEmbeddingLookupSparse(self): + vocab_size = 12 + batch_size = 4 + param_shape = [2, 3] + sp_ids, sp_weights, _, _, _ = ( + self._RandomIdsAndWeights(batch_size, vocab_size)) + + for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3], + ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): + with self.test_session(): + x, params, _ = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + + y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + x_name = [_PName(i) for i in range(num_shards)] + x_init_value = [params[x_n + ":0"] for x_n in x_name] + x_shape = [i.shape for i in x_init_value] + y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:]) + err = gradient_checker.compute_gradient_error( + x, x_shape, y, y_shape, x_init_value=x_init_value) + self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3) + + def testIncompatibleShapes(self): + with self.test_session(): + x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32) + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64), + constant_op.constant([0, 1, 2], dtypes.int32), + constant_op.constant([2, 2], dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1]], dtypes.int64), + constant_op.constant([12.0, 5.0], dtypes.float32), + constant_op.constant([1, 2], dtypes.int64)) + + with self.assertRaises(ValueError): + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, sp_ids, sp_weights, combiner="mean") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 61f2b2ad149..8cd37882570 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -600,9 +600,9 @@ class EmbeddingLookupSparseTest(test.TestCase): grouped_ignored_weights = self._GroupByBatchEntry( np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) - for num_shards, combiner, dtype, ignore_weights, use_aggregation in \ - itertools.product([1, 5], ["sum", "mean", "sqrtn"], - [dtypes.float32, dtypes.float64], [True, False], [True, False]): + for num_shards, combiner, dtype, ignore_weights in itertools.product( + [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): with self.test_session(): p, params, feed_dict = _EmbeddingParams( @@ -611,8 +611,7 @@ class EmbeddingLookupSparseTest(test.TestCase): p, sp_ids, None if ignore_weights else sp_weights, - combiner=combiner, - use_aggregation=use_aggregation) + combiner=combiner) self.assertEqual(embedding_sum.get_shape().as_list(), expected_lookup_result_shape) @@ -640,9 +639,9 @@ class 
EmbeddingLookupSparseTest(test.TestCase): sp_ids, sp_weights, _, _, _ = ( self._RandomIdsAndWeights(batch_size, vocab_size)) - for num_shards, combiner, dtype, ignore_weights, use_aggregation in \ - itertools.product([1, 3], ["sum", "mean", "sqrtn"], - [dtypes.float32, dtypes.float64], [True, False], [True, False]): + for num_shards, combiner, dtype, ignore_weights in itertools.product( + [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], + [True, False]): with self.test_session(): x, params, _ = _EmbeddingParams( num_shards, vocab_size, shape=param_shape, dtype=dtype) @@ -651,8 +650,7 @@ class EmbeddingLookupSparseTest(test.TestCase): x, sp_ids, None if ignore_weights else sp_weights, - combiner=combiner, - use_aggregation=use_aggregation) + combiner=combiner) x_name = [_PName(i) for i in range(num_shards)] x_init_value = [params[x_n + ":0"] for x_n in x_name] x_shape = [i.shape for i in x_init_value] diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 8116cb648f8..2aeb9ce14d3 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -42,12 +42,8 @@ def _do_gather(params, ids, validate_indices=True, name=None): def embedding_lookup(params, ids, partition_strategy="mod", name=None, - validate_indices=True, max_norm=None, - use_aggregation=False, - weights=None, idx=None, segment_ids=None): - """Looks up `ids` in a list of embedding tensors. Note that `use_aggregation`, - `weights`, `idx` and `segment_ids` are for internal use, user should not use - them. + validate_indices=True, max_norm=None): + """Looks up `ids` in a list of embedding tensors. This function is used to perform parallel lookups on the list of tensors in `params`. It is a generalization of @@ -98,25 +94,6 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, """ if params is None or params == []: # pylint: disable=g-explicit-bool-comparison raise ValueError("Need at least one param") - if use_aggregation: - if segment_ids is None: - raise ValueError("segment_ids must not be None \ - when use_aggregation is True") - if weights is not None: - if idx is not None: - raise ValueError("idx must be None \ - when weights is not None and use_aggregation is True") - weights.get_shape().assert_is_compatible_with(segment_ids.get_shape()) - else: - if idx is None: - raise ValueError("idx must not be None \ - when weights is None and use_aggregation is True") - idx.get_shape().assert_is_compatible_with(segment_ids.get_shape()) - else: - if weights is not None or idx is not None or segment_ids is not None: - raise ValueError("weights, idx and segment_ids must be None \ - when use_aggregation is False") - if isinstance(params, variables.PartitionedVariable): params = list(params) # Iterate to get the underlying Variables. 
if not isinstance(params, list): @@ -137,31 +114,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params") if np == 1: with ops.colocate_with(params[0]): - ret = maybe_normalize( + return maybe_normalize( _do_gather( params[0], ids, validate_indices=validate_indices, name=name)) - if not use_aggregation: - return ret - else: - ignore_weights = weights is None - if not ignore_weights: - if weights.dtype != ret.dtype: - weights = math_ops.cast(weights, ret.dtype) - # Reshape to allow broadcast - ones = array_ops.fill( - array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1) - bcast_weights_shape = array_ops.concat( - [array_ops.shape(weights), ones], 0) - orig_weights_shape = weights.get_shape() - weights = array_ops.reshape(weights, bcast_weights_shape) - # Set weights shape after reshape - if ret.get_shape().ndims is not None: - weights.set_shape(orig_weights_shape.concatenate( - [1 for _ in range(ret.get_shape().ndims - 1)])) - ret *= weights - return math_ops.segment_sum(ret, segment_ids) - else: - return math_ops.sparse_segment_sum(ret, idx, segment_ids) else: ids = ops.convert_to_tensor(ids, name="ids") flat_ids = array_ops.reshape(ids, [-1]) @@ -224,131 +179,39 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, partitioned_result.append( _do_gather(params[p], gather_ids[p], validate_indices=validate_indices)) - - if not use_aggregation: - # Stitch these back together - ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result, - name=name) - # Reshape to reverse the flattening of ids. - element_shape = params[0].get_shape()[1:] - for p in params[1:]: - element_shape = element_shape.merge_with(p.get_shape()[1:]) - if element_shape.is_fully_defined(): - ret = array_ops.reshape(ret, - array_ops.concat( - [array_ops.shape(ids), element_shape], 0)) - else: - # It's important that we compute params[0].shape on the right device - # to avoid data motion. - with ops.colocate_with(params[0]): - params_shape = array_ops.shape(params[0]) - ret = array_ops.reshape(ret, - array_ops.concat([ - array_ops.shape(ids), - array_ops.slice(params_shape, [1], [-1]) - ], 0)) - # output shape = ids.shape + params[*].shape[1:] - # Normally the reshape is sufficient, but setting shape explicitly - # teaches shape inference that params[1:].get_shape() matters. - ret.set_shape(ids.get_shape().concatenate(element_shape)) - return maybe_normalize(ret) + # Stitch these back together + ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result, + name=name) + # Reshape to reverse the flattening of ids. + element_shape = params[0].get_shape()[1:] + for p in params[1:]: + element_shape = element_shape.merge_with(p.get_shape()[1:]) + if element_shape.is_fully_defined(): + ret = array_ops.reshape(ret, + array_ops.concat( + [array_ops.shape(ids), element_shape], 0)) else: - # We use distributed aggregation. - ignore_weights = weights is None - if not ignore_weights: - # Partition weights according to pindices. - partitioned_weight = [] - for p in xrange(np): - partitioned_weight.append(array_ops.gather(weights, pindices[p])) - # Reshape each partition result. 
@@ -224,131 +179,39 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
         partitioned_result.append(
             _do_gather(params[p], gather_ids[p],
                        validate_indices=validate_indices))
-
-      if not use_aggregation:
-        # Stitch these back together
-        ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
-                                           name=name)
-        # Reshape to reverse the flattening of ids.
-        element_shape = params[0].get_shape()[1:]
-        for p in params[1:]:
-          element_shape = element_shape.merge_with(p.get_shape()[1:])
-        if element_shape.is_fully_defined():
-          ret = array_ops.reshape(ret,
-                                  array_ops.concat(
-                                      [array_ops.shape(ids), element_shape], 0))
-        else:
-          # It's important that we compute params[0].shape on the right device
-          # to avoid data motion.
-          with ops.colocate_with(params[0]):
-            params_shape = array_ops.shape(params[0])
-          ret = array_ops.reshape(ret,
-                                  array_ops.concat([
-                                      array_ops.shape(ids),
-                                      array_ops.slice(params_shape, [1], [-1])
-                                  ], 0))
-        # output shape = ids.shape + params[*].shape[1:]
-        # Normally the reshape is sufficient, but setting shape explicitly
-        # teaches shape inference that params[1:].get_shape() matters.
-        ret.set_shape(ids.get_shape().concatenate(element_shape))
-        return maybe_normalize(ret)
+      # Stitch these back together
+      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
+                                         name=name)
+      # Reshape to reverse the flattening of ids.
+      element_shape = params[0].get_shape()[1:]
+      for p in params[1:]:
+        element_shape = element_shape.merge_with(p.get_shape()[1:])
+      if element_shape.is_fully_defined():
+        ret = array_ops.reshape(ret,
+                                array_ops.concat(
+                                    [array_ops.shape(ids), element_shape], 0))
       else:
-        # We use distributed aggregation.
-        ignore_weights = weights is None
-        if not ignore_weights:
-          # Partition weights according to pindices.
-          partitioned_weight = []
-          for p in xrange(np):
-            partitioned_weight.append(array_ops.gather(weights, pindices[p]))
-        # Reshape each partition result.
-        element_shape = params[0].get_shape()[1:]
-        for p in params[1:]:
-          element_shape = element_shape.merge_with(p.get_shape()[1:])
-        if element_shape.is_fully_defined():
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = array_ops.reshape(partitioned_result[p],
-                  array_ops.concat(
-                      [array_ops.shape(pindices[p]), element_shape], 0))
-        else:
-          with ops.colocate_with(params[0]):
-            params_shape = array_ops.shape(params[0])
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = array_ops.reshape(partitioned_result[p],
-                  array_ops.concat([array_ops.shape(pindices[p]),
-                      array_ops.slice(params_shape, [1], [-1])], 0))
-        # Normalize each partition result.
-        for p in xrange(np):
-          with ops.colocate_with(params[p]):
-            partitioned_result[p] = maybe_normalize(partitioned_result[p])
-        if not ignore_weights:
-          # Multiply each partition result with partition weights.
-          for p in xrange(np):
-            with ops.colocate_with(params[p]):
-              if partitioned_weight[p].dtype != partitioned_result[p].dtype:
-                partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
-                                                      partitioned_result[p].dtype)
-              # Reshape partition weights.
-              ones = array_ops.fill(
-                  array_ops.expand_dims(
-                      array_ops.rank(partitioned_result[p]) - 1, 0), 1)
-              bcast_weights_shape = array_ops.concat(
-                  [array_ops.shape(partitioned_weight[p]), ones], 0)
-              orig_weights_shape = partitioned_weight[p].get_shape()
-              partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
-                                                        bcast_weights_shape)
-              if partitioned_result[p].get_shape().ndims is not None:
-                partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
-                    [1 for _ in range(
-                        partitioned_result[p].get_shape().ndims - 1)]))
-              partitioned_result[p] *= partitioned_weight[p]
-        partitioned_segment_ids = []
-        for p in xrange(np):
-          if not ignore_weights:
-            # Partition segment_ids according to pindices.
-            p_segment_ids = array_ops.gather(segment_ids, pindices[p])
-            # Number the p_segment_ids to meet segment_sum's requirements. Note
-            # that unique_p_segment_ids contains unique segment ids of this
-            # partiton and these ids' order is unchanged.
-            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
-                p_segment_ids)
-            partitioned_segment_ids.append(unique_p_segment_ids)
-            # segment_sum this partition's result.
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = math_ops.segment_sum(
-                  partitioned_result[p], unique_p_segment_idx)
-          else:
-            # When ignore weights, we need to get indexs of elements in idx and
-            # segment_ids.
-            _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
-            all_idx = math_ops.range(array_ops.shape(idx)[0])
-            _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
-            # Gather segment_ids and idx according to indexs.
-            p_segment_ids = array_ops.gather(segment_ids, include_idx)
-            p_idx = array_ops.gather(idx, include_idx)
-            # Number the p_segment_ids, same as ignore_weights case above.
-            unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
-                p_segment_ids)
-            _, unique_p_idx_idx = array_ops.unique(p_idx)
-            partitioned_segment_ids.append(unique_p_segment_ids)
-            with ops.colocate_with(params[p]):
-              partitioned_result[p] = math_ops.sparse_segment_sum(
-                  partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
-        # Concat each partition's segment_ids and result for final segment_sum.
-        concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
-        concat_partitioned_result = array_ops.concat(partitioned_result, 0)
-        return math_ops.unsorted_segment_sum(
-            concat_partitioned_result, concat_segment_ids,
-            math_ops.reduce_max(concat_segment_ids) + 1)
+        # It's important that we compute params[0].shape on the right device
+        # to avoid data motion.
+        with ops.colocate_with(params[0]):
+          params_shape = array_ops.shape(params[0])
+        ret = array_ops.reshape(ret,
+                                array_ops.concat([
+                                    array_ops.shape(ids),
+                                    array_ops.slice(params_shape, [1], [-1])
+                                ], 0))
+      # output shape = ids.shape + params[*].shape[1:]
+      # Normally the reshape is sufficient, but setting shape explicitly
+      # teaches shape inference that params[1:].get_shape() matters.
+      ret.set_shape(ids.get_shape().concatenate(element_shape))
+      return maybe_normalize(ret)
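The distributed path removed above renumbers each partition's segment ids with
unique() so segment_sum can run locally on the shard, while keeping the
original ids for the final cross-partition unsorted_segment_sum. A small sketch
of that renumbering trick (values are made up, not taken from the patch):

    import tensorflow as tf

    # Segment ids of the rows that landed on one partition: sorted, but with
    # gaps that segment_sum would otherwise pad with zero rows.
    p_segment_ids = tf.constant([0, 3, 3, 7])
    unique_ids, renumbered = tf.unique(p_segment_ids)  # [0 3 7], [0 1 1 2]

    data = tf.constant([[1.], [2.], [3.], [4.]])
    partial = tf.segment_sum(data, renumbered)         # [[1.] [5.] [4.]]
    # unique_ids travels with partial so the per-partition results can later
    # be merged with unsorted_segment_sum over the concatenated segment ids.
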
 
 
 def embedding_lookup_sparse(params, sp_ids, sp_weights,
                             partition_strategy="mod",
                             name=None,
                             combiner=None,
-                            max_norm=None,
-                            use_aggregation=False):
+                            max_norm=None):
   """Computes embeddings for the given ids and weights.
 
   This op assumes that there is at least one id for each row in the dense tensor
@@ -381,9 +244,6 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
       squares of the weights.
     max_norm: If not None, each embedding is normalized to have l2 norm equal
       to max_norm before combining.
-    use_aggregation: If True, embeddings belonging to same param are aggregated
-      on that device first. This option is intented to reduce data transmission
-      and increase concurrency.
 
   Returns:
     A dense tensor representing the combined embeddings for the
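In terms of segment arithmetic, "sum" produces sum(w_i * e_i) per output row,
"mean" divides that by sum(w_i), and "sqrtn" divides by sqrt(sum(w_i^2)). A
numeric sketch of the three combiners (hypothetical values, TF 1.x ops):

    import tensorflow as tf

    emb = tf.constant([[2., 4.], [6., 8.]])  # already weight-multiplied rows
    w = tf.constant([[1.], [3.]])            # the matching weights
    seg = tf.constant([0, 0])                # both rows feed output row 0

    total = tf.segment_sum(emb, seg)                            # [[8. 12.]]
    mean = total / tf.segment_sum(w, seg)                       # [[2. 3.]]
    sqrtn = total / tf.sqrt(tf.segment_sum(tf.pow(w, 2), seg))  # [[8. 12.]] / sqrt(10)
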
@@ -458,89 +318,56 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
   else:
     idx = None
 
-  if not use_aggregation:
-    embeddings = embedding_lookup(
-        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
-    if not ignore_weights:
-      weights = sp_weights.values
-      if weights.dtype != embeddings.dtype:
-        weights = math_ops.cast(weights, embeddings.dtype)
-
-      # Reshape weights to allow broadcast
-      ones = array_ops.fill(
-          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
-      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
-                                             0)
-
-      orig_weights_shape = weights.get_shape()
-      weights = array_ops.reshape(weights, bcast_weights_shape)
-
-      # Set the weight shape, since after reshaping to bcast_weights_shape,
-      # the shape becomes None.
-      if embeddings.get_shape().ndims is not None:
-        weights.set_shape(orig_weights_shape.concatenate(
-            [1 for _ in range(embeddings.get_shape().ndims - 1)]))
-
-      embeddings *= weights
-
-      if combiner == "sum":
-        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
-      elif combiner == "mean":
-        embeddings = math_ops.segment_sum(embeddings, segment_ids)
-        weight_sum = math_ops.segment_sum(weights, segment_ids)
-        embeddings = math_ops.div(embeddings, weight_sum, name=name)
-      elif combiner == "sqrtn":
-        embeddings = math_ops.segment_sum(embeddings, segment_ids)
-        weights_squared = math_ops.pow(weights, 2)
-        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
-        weight_sum_sqrt = math_ops.sqrt(weight_sum)
-        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
-      else:
-        assert False, "Unrecognized combiner"
-    else:
-      assert idx is not None
-      if combiner == "sum":
-        embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
-                                                 name=name)
-      elif combiner == "mean":
-        embeddings = math_ops.sparse_segment_mean(embeddings, idx,
-                                                  segment_ids, name=name)
-      elif combiner == "sqrtn":
-        embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
-                                                    segment_ids, name=name)
-      else:
-        assert False, "Unrecognized combiner"
-  else:
-    weights = None if ignore_weights else sp_weights.values
-    embeddings = embedding_lookup(
-        params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
-        use_aggregation=True,
-        weights=weights, idx=idx, segment_ids=segment_ids)
-    # Set weights to all one if ignore weights.
-    if ignore_weights:
-      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
+  embeddings = embedding_lookup(
+      params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
+  if not ignore_weights:
+    weights = sp_weights.values
     if weights.dtype != embeddings.dtype:
       weights = math_ops.cast(weights, embeddings.dtype)
-    # Reshape weights.
+
+    # Reshape weights to allow broadcast
     ones = array_ops.fill(
         array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                           0)
+
+    orig_weights_shape = weights.get_shape()
     weights = array_ops.reshape(weights, bcast_weights_shape)
+
+    # Set the weight shape, since after reshaping to bcast_weights_shape,
+    # the shape becomes None.
     if embeddings.get_shape().ndims is not None:
       weights.set_shape(orig_weights_shape.concatenate(
           [1 for _ in range(embeddings.get_shape().ndims - 1)]))
 
-    if combiner == "mean":
+    embeddings *= weights
+
+    if combiner == "sum":
+      embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
+    elif combiner == "mean":
+      embeddings = math_ops.segment_sum(embeddings, segment_ids)
       weight_sum = math_ops.segment_sum(weights, segment_ids)
-      embeddings = math_ops.div(embeddings, weight_sum)
+      embeddings = math_ops.div(embeddings, weight_sum, name=name)
     elif combiner == "sqrtn":
+      embeddings = math_ops.segment_sum(embeddings, segment_ids)
      weights_squared = math_ops.pow(weights, 2)
      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
-      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
-  elif combiner != "sum":
+      embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
+    else:
+      assert False, "Unrecognized combiner"
+  else:
+    assert idx is not None
+    if combiner == "sum":
+      embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
+                                               name=name)
+    elif combiner == "mean":
+      embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
+                                                name=name)
+    elif combiner == "sqrtn":
+      embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
+                                                  segment_ids, name=name)
+    else:
       assert False, "Unrecognized combiner"
 
   return embeddings
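After this revert, the unweighted path again relies on idx from tf.unique(ids)
to fan the de-duplicated lookup back out per occurrence via the
sparse_segment_* ops. Roughly, with illustrative values (tf.gather stands in
for the full embedding_lookup here):

    import tensorflow as tf

    params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
    ids = tf.constant([0, 2, 2])           # raw ids from sp_ids.values
    unique_ids, idx = tf.unique(ids)       # [0 2], [0 1 1]
    segment_ids = tf.constant([0, 0, 1])   # dense output row for each id

    embeddings = tf.gather(params, unique_ids)  # look up each id only once
    out = tf.sparse_segment_mean(embeddings, idx, segment_ids)

    with tf.Session() as sess:
      print(sess.run(out))  # [[3. 4.] [5. 6.]]
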
From 300eab45906e3b39afcb20c8089f8a0e935aa2fc Mon Sep 17 00:00:00 2001
From: Ziming Dong
Date: Thu, 30 Mar 2017 12:44:22 -0400
Subject: [PATCH 6/6] fix python 2 and 3 compatibility bug

---
 tensorflow/contrib/layers/python/layers/embedding_ops.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index e739daa2fcd..e42e885364c 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
 from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
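For completeness, the bug PATCH 6/6 addresses: xrange is a builtin in Python 2
but absent in Python 3, so modules that use it (as the aggregation loops here
do) must import the shim explicitly. A minimal demonstration:

    from six.moves import xrange  # pylint: disable=redefined-builtin

    # Behaves like Python 2's lazy xrange under both interpreters.
    for p in xrange(3):
      print(p)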