move code to contrib

This commit is contained in:
Ziming Dong 2017-03-29 12:46:44 -04:00
parent f00b5c9b2e
commit 71821efb54
4 changed files with 618 additions and 246 deletions
tensorflow

View File

@ -26,15 +26,18 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
__all__ = [
"safe_embedding_lookup_sparse", "scattered_embedding_lookup",
"scattered_embedding_lookup_sparse", "embedding_lookup_unique"
"scattered_embedding_lookup_sparse", "embedding_lookup_unique",
"embedding_lookup_sparse_with_distributed_aggregation"
]
@ -548,3 +551,326 @@ def _sampled_scattered_embedding_lookup_sparse(params,
return math_ops.unsorted_segment_sum(embeddings, segment_ids,
num_segments=num_segments,
name=name_scope)
def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,
sp_weights, partition_strategy="mod", name=None, combiner=None,
max_norm=None):
"""Computes embeddings for the given ids and weights.
Embeddings belonging to same param are aggregated on that device first. This
op is intended to decrease data transmission and improve parallelism. See
`tf.nn.embedding_lookup_sparse` for the functionality and example of this op.
Args:
params: A single tensor representing the complete embedding tensor,
or a list of P tensors all of same shape except for the first dimension,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
where N is typically batch size and M is arbitrary.
sp_weights: either a SparseTensor of float / double weights, or None to
indicate all weights should be taken to be 1. If specified, sp_weights
must have exactly the same shape and indices as sp_ids.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: Optional name for the op.
combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
and "sum" are supported.
"sum" computes the weighted sum of the embedding results for each row.
"mean" is the weighted sum divided by the total weight.
"sqrtn" is the weighted sum divided by the square root of the sum of the
squares of the weights.
max_norm: If not None, each embedding is normalized to have l2 norm equal
to max_norm before combining.
Returns:
A dense tensor representing the combined embeddings for the
sparse ids. For each row in the dense tensor represented by sp_ids, the op
looks up the embeddings for all ids in that row, multiplies them by the
corresponding weight, and combines these embeddings as specified.
Raises:
TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
None nor SparseTensor.
ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
"""
if combiner is None:
logging.warn("The default value of combiner will change from \"mean\" "
"to \"sqrtn\" after 2016/11/01.")
combiner = "mean"
if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
if not isinstance(sp_ids, sparse_tensor.SparseTensor):
raise TypeError("sp_ids must be SparseTensor")
ignore_weights = sp_weights is None
if not ignore_weights:
if not isinstance(sp_weights, sparse_tensor.SparseTensor):
raise TypeError("sp_weights must be either None or SparseTensor")
sp_ids.values.get_shape().assert_is_compatible_with(
sp_weights.values.get_shape())
sp_ids.indices.get_shape().assert_is_compatible_with(
sp_weights.indices.get_shape())
sp_ids.dense_shape.get_shape().assert_is_compatible_with(
sp_weights.dense_shape.get_shape())
# TODO(yleon): Add enhanced node assertions to verify that sp_ids and
# sp_weights have equal indices and shapes.
with ops.name_scope(name, "embedding_lookup_sparse",
params + [sp_ids]) as name:
segment_ids = sp_ids.indices[:, 0]
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
if ignore_weights:
ids, idx = array_ops.unique(ids)
else:
idx = None
weights = None if ignore_weights else sp_weights.values
embeddings = _embedding_lookup_with_distributed_aggregation(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
weights=weights, idx=idx, segment_ids=segment_ids)
# Set weights to all one if ignore weights.
if ignore_weights:
weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
# Reshape weights.
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
if embeddings.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))
if combiner == "mean":
weight_sum = math_ops.segment_sum(weights, segment_ids)
embeddings = math_ops.div(embeddings, weight_sum)
elif combiner == "sqrtn":
weights_squared = math_ops.pow(weights, 2)
weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
embeddings = math_ops.div(embeddings, weight_sum_sqrt)
elif combiner != "sum":
assert False, "Unrecognized combiner"
return embeddings
def _do_gather(params, ids, validate_indices=True, name=None):
"""Deals with doing gather differently for resource variables."""
if isinstance(params, resource_variable_ops.ResourceVariable):
return params.sparse_read(ids, name=name)
return array_ops.gather(
params, ids, name=name, validate_indices=validate_indices)
def _embedding_lookup_with_distributed_aggregation(params, ids,
partition_strategy="mod", name=None, validate_indices=True, max_norm=None,
weights=None, idx=None, segment_ids=None):
""" Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
if params is None or params == []: # pylint: disable=g-explicit-bool-comparison
raise ValueError("Need at least one param")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
def maybe_normalize(x):
if max_norm is not None:
if x.get_shape().ndims is not None:
ndims = x.get_shape().ndims
else:
ndims = array_ops.size(array_ops.shape(x))
return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
return x
with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
params + [ids]) as name:
np = len(params) # Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
if not any(isinstance(p, resource_variable_ops.ResourceVariable)
for p in params):
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:
with ops.colocate_with(params[0]):
ret = maybe_normalize(
_do_gather(
params[0], ids, validate_indices=validate_indices))
ignore_weights = weights is None
if not ignore_weights:
if weights.dtype != ret.dtype:
weights = math_ops.cast(weights, ret.dtype)
# Reshape to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(weights), ones], 0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set weights shape after reshape
if ret.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(ret.get_shape().ndims - 1)]))
ret *= weights
return math_ops.segment_sum(ret, segment_ids, name=name)
else:
return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
else:
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
original_indices = math_ops.range(array_ops.size(flat_ids))
# Create p_assignments and set new_ids depending on the strategy.
if partition_strategy == "mod":
p_assignments = flat_ids % np
new_ids = flat_ids // np
elif partition_strategy == "div":
# Compute num_total_ids as the sum of dim-0 of params, then assign to
# partitions based on a constant number of ids per partition. Optimize
# if we already know the full shape statically.
dim_0_size = params[0].get_shape()[0]
for p in xrange(1, np):
dim_0_size += params[p].get_shape()[0]
if dim_0_size.value:
num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
else:
dim_0_sizes = []
for p in xrange(np):
if params[p].get_shape()[0].value is not None:
dim_0_sizes.append(params[p].get_shape()[0].value)
else:
with ops.colocate_with(params[p]):
dim_0_sizes.append(array_ops.shape(params[p])[0])
num_total_ids = math_ops.reduce_sum(
math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
ids_per_partition = num_total_ids // np
extras = num_total_ids % np
p_assignments = math_ops.maximum(
flat_ids // (ids_per_partition + 1),
(flat_ids - extras) // ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
is_in_first_extras_partitions = math_ops.cast(
p_assignments < extras, flat_ids.dtype)
new_ids = (
is_in_first_extras_partitions * (
flat_ids % (ids_per_partition + 1)) +
(1 - is_in_first_extras_partitions) * (
(flat_ids - extras) % ids_per_partition))
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices,
p_assignments, np)
# Do np separate lookups, finding embeddings for plist[p] in params[p]
partitioned_result = []
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result.append(
_do_gather(params[p], gather_ids[p],
validate_indices=validate_indices))
ignore_weights = weights is None
if not ignore_weights:
# Partition weights according to pindices.
partitioned_weight = []
for p in xrange(np):
partitioned_weight.append(array_ops.gather(weights, pindices[p]))
# Reshape each partition result.
element_shape = params[0].get_shape()[1:]
for p in params[1:]:
element_shape = element_shape.merge_with(p.get_shape()[1:])
if element_shape.is_fully_defined():
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat(
[array_ops.shape(pindices[p]), element_shape], 0))
else:
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat([array_ops.shape(pindices[p]),
array_ops.slice(params_shape, [1], [-1])], 0))
# Normalize each partition result.
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = maybe_normalize(partitioned_result[p])
if not ignore_weights:
# Multiply each partition result with partition weights.
for p in xrange(np):
with ops.colocate_with(params[p]):
if partitioned_weight[p].dtype != partitioned_result[p].dtype:
partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
partitioned_result[p].dtype)
# Reshape partition weights.
ones = array_ops.fill(
array_ops.expand_dims(
array_ops.rank(partitioned_result[p]) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(partitioned_weight[p]), ones], 0)
orig_weights_shape = partitioned_weight[p].get_shape()
partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
bcast_weights_shape)
if partitioned_result[p].get_shape().ndims is not None:
partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
[1 for _ in range(
partitioned_result[p].get_shape().ndims - 1)]))
partitioned_result[p] *= partitioned_weight[p]
partitioned_segment_ids = []
for p in xrange(np):
if not ignore_weights:
# Partition segment_ids according to pindices.
p_segment_ids = array_ops.gather(segment_ids, pindices[p])
# Number the p_segment_ids to meet segment_sum's requirements. Note
# that unique_p_segment_ids contains unique segment ids of this
# partiton and these ids' order is unchanged.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
partitioned_segment_ids.append(unique_p_segment_ids)
# segment_sum this partition's result.
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.segment_sum(
partitioned_result[p], unique_p_segment_idx)
else:
# When ignore weights, we need to get indexs of elements in idx and
# segment_ids.
_, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
all_idx = math_ops.range(array_ops.shape(idx)[0])
_, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
# Gather segment_ids and idx according to indexs.
p_segment_ids = array_ops.gather(segment_ids, include_idx)
p_idx = array_ops.gather(idx, include_idx)
# Number the p_segment_ids, same as ignore_weights case above.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
_, unique_p_idx_idx = array_ops.unique(p_idx)
partitioned_segment_ids.append(unique_p_segment_ids)
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.sparse_segment_sum(
partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
# Concat each partition's segment_ids and result for final segment_sum.
concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
concat_partitioned_result = array_ops.concat(partitioned_result, 0)
return math_ops.unsorted_segment_sum(
concat_partitioned_result, concat_segment_ids,
math_ops.reduce_max(concat_segment_ids) + 1, name=name)

View File

@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class SafeEmbeddingLookupSparseTest(test.TestCase):
@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
self.assertAllClose(result.eval(), result_abc.eval())
def _PName(param_id):
return "p" + str(param_id)
def _EmbeddingParams(num_shards,
vocab_size,
dtype=dtypes.float32,
shape=None,
use_shapeless_placeholder=False):
p = []
params = {}
feed_dict = {}
if not shape:
shape = [10]
for i in range(num_shards):
shard_shape = [vocab_size // num_shards] + shape
if i < vocab_size % num_shards: # Excess goes evenly on the first shards
shard_shape[0] += 1
param_name = _PName(i)
if use_shapeless_placeholder:
param = array_ops.placeholder(dtype, shape=None, name=param_name)
else:
param = constant_op.constant(
1.0, shape=shard_shape, dtype=dtype, name=param_name)
p.append(param)
np_type = "f" if dtype == dtypes.float32 else "d"
val = (np.random.rand(*shard_shape).astype(np_type)) + 1
params[param_name + ":0"] = val
feed_dict[param.name] = val
return p, params, feed_dict
def _EmbeddingResult(params,
id_vals,
num_shards,
vocab_size,
partition_strategy="mod",
weight_vals=None):
if weight_vals is None:
weight_vals = np.copy(id_vals)
weight_vals.fill(1)
values = []
weights = []
weights_squared = []
for ids, wts in zip(id_vals, weight_vals):
value_aggregation = None
weight_aggregation = None
squared_weight_aggregation = None
if isinstance(ids, compat.integral_types):
ids = [ids]
wts = [wts]
for i, weight_value in zip(ids, wts):
if partition_strategy == "mod":
val = np.copy(params[_PName(i % num_shards) + ":0"][
i // num_shards, :]) * weight_value
elif partition_strategy == "div":
ids_per_partition, extras = divmod(vocab_size, num_shards)
threshold = extras * (ids_per_partition + 1)
if i < threshold:
partition = i // (ids_per_partition + 1)
offset = i % (ids_per_partition + 1)
else:
partition = extras + (i - threshold) // ids_per_partition
offset = (i - threshold) % ids_per_partition
val = np.copy(params[_PName(partition) + ":0"][
offset, :]) * weight_value
else:
assert False
if value_aggregation is None:
assert weight_aggregation is None
assert squared_weight_aggregation is None
value_aggregation = val
weight_aggregation = weight_value
squared_weight_aggregation = weight_value * weight_value
else:
assert weight_aggregation is not None
assert squared_weight_aggregation is not None
value_aggregation += val
weight_aggregation += weight_value
squared_weight_aggregation += weight_value * weight_value
values.append(value_aggregation)
weights.append(weight_aggregation)
weights_squared.append(squared_weight_aggregation)
values = np.array(values).astype(np.float32)
weights = np.array(weights).astype(np.float32)
weights_squared = np.array(weights_squared).astype(np.float32)
return values, weights, weights_squared
class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
def _RandomIdsAndWeights(self, batch_size, vocab_size):
max_val_per_entry = 6
vals_per_batch_entry = np.random.randint(
1, max_val_per_entry, size=batch_size)
num_vals = np.sum(vals_per_batch_entry)
ids = np.random.randint(vocab_size, size=num_vals)
weights = 1 + np.random.rand(num_vals)
indices = []
for batch_entry, num_val in enumerate(vals_per_batch_entry):
for val_index in range(num_val):
indices.append([batch_entry, val_index])
shape = [batch_size, max_val_per_entry]
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant(indices, dtypes.int64),
constant_op.constant(ids, dtypes.int32),
constant_op.constant(shape, dtypes.int64))
sp_weights = sparse_tensor_lib.SparseTensor(
constant_op.constant(indices, dtypes.int64),
constant_op.constant(weights, dtypes.float32),
constant_op.constant(shape, dtypes.int64))
return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
grouped_vals = []
index = 0
for num_val in vals_per_batch_entry:
grouped_vals.append(list(vals[index:(index + num_val)]))
index += num_val
return grouped_vals
def testEmbeddingLookupSparse(self):
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
expected_lookup_result_shape = [None] + param_shape
sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
self._RandomIdsAndWeights(batch_size, vocab_size))
grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
grouped_ignored_weights = self._GroupByBatchEntry(
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
params,
grouped_ids,
num_shards,
vocab_size,
weight_vals=grouped_ignored_weights if ignore_weights else
grouped_weights)
if combiner == "mean":
np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
if combiner == "sqrtn":
np_embedding_sum /= np.reshape(
np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
self.assertAllClose(np_embedding_sum, tf_embedding_sum)
def testGradientsEmbeddingLookupSparse(self):
vocab_size = 12
batch_size = 4
param_shape = [2, 3]
sp_ids, sp_weights, _, _, _ = (
self._RandomIdsAndWeights(batch_size, vocab_size))
for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
x,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
err = gradient_checker.compute_gradient_error(
x, x_shape, y, y_shape, x_init_value=x_init_value)
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
with self.test_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
constant_op.constant([0, 1, 2], dtypes.int32),
constant_op.constant([2, 2], dtypes.int64))
sp_weights = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
constant_op.constant([12.0, 5.0], dtypes.float32),
constant_op.constant([1, 2], dtypes.int64))
with self.assertRaises(ValueError):
embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
x, sp_ids, sp_weights, combiner="mean")
if __name__ == "__main__":
test.main()

View File

@ -600,9 +600,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
grouped_ignored_weights = self._GroupByBatchEntry(
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
for num_shards, combiner, dtype, ignore_weights, use_aggregation in \
itertools.product([1, 5], ["sum", "mean", "sqrtn"],
[dtypes.float32, dtypes.float64], [True, False], [True, False]):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
p, params, feed_dict = _EmbeddingParams(
@ -611,8 +611,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner,
use_aggregation=use_aggregation)
combiner=combiner)
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
@ -640,9 +639,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
sp_ids, sp_weights, _, _, _ = (
self._RandomIdsAndWeights(batch_size, vocab_size))
for num_shards, combiner, dtype, ignore_weights, use_aggregation in \
itertools.product([1, 3], ["sum", "mean", "sqrtn"],
[dtypes.float32, dtypes.float64], [True, False], [True, False]):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@ -651,8 +650,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
x,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner,
use_aggregation=use_aggregation)
combiner=combiner)
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]

View File

@ -42,12 +42,8 @@ def _do_gather(params, ids, validate_indices=True, name=None):
def embedding_lookup(params, ids, partition_strategy="mod", name=None,
validate_indices=True, max_norm=None,
use_aggregation=False,
weights=None, idx=None, segment_ids=None):
"""Looks up `ids` in a list of embedding tensors. Note that `use_aggregation`,
`weights`, `idx` and `segment_ids` are for internal use, user should not use
them.
validate_indices=True, max_norm=None):
"""Looks up `ids` in a list of embedding tensors.
This function is used to perform parallel lookups on the list of
tensors in `params`. It is a generalization of
@ -98,25 +94,6 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
"""
if params is None or params == []: # pylint: disable=g-explicit-bool-comparison
raise ValueError("Need at least one param")
if use_aggregation:
if segment_ids is None:
raise ValueError("segment_ids must not be None \
when use_aggregation is True")
if weights is not None:
if idx is not None:
raise ValueError("idx must be None \
when weights is not None and use_aggregation is True")
weights.get_shape().assert_is_compatible_with(segment_ids.get_shape())
else:
if idx is None:
raise ValueError("idx must not be None \
when weights is None and use_aggregation is True")
idx.get_shape().assert_is_compatible_with(segment_ids.get_shape())
else:
if weights is not None or idx is not None or segment_ids is not None:
raise ValueError("weights, idx and segment_ids must be None \
when use_aggregation is False")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
@ -137,31 +114,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:
with ops.colocate_with(params[0]):
ret = maybe_normalize(
return maybe_normalize(
_do_gather(
params[0], ids, validate_indices=validate_indices, name=name))
if not use_aggregation:
return ret
else:
ignore_weights = weights is None
if not ignore_weights:
if weights.dtype != ret.dtype:
weights = math_ops.cast(weights, ret.dtype)
# Reshape to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(weights), ones], 0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set weights shape after reshape
if ret.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(ret.get_shape().ndims - 1)]))
ret *= weights
return math_ops.segment_sum(ret, segment_ids)
else:
return math_ops.sparse_segment_sum(ret, idx, segment_ids)
else:
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
@ -224,131 +179,39 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
partitioned_result.append(
_do_gather(params[p], gather_ids[p],
validate_indices=validate_indices))
if not use_aggregation:
# Stitch these back together
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
name=name)
# Reshape to reverse the flattening of ids.
element_shape = params[0].get_shape()[1:]
for p in params[1:]:
element_shape = element_shape.merge_with(p.get_shape()[1:])
if element_shape.is_fully_defined():
ret = array_ops.reshape(ret,
array_ops.concat(
[array_ops.shape(ids), element_shape], 0))
else:
# It's important that we compute params[0].shape on the right device
# to avoid data motion.
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
ret = array_ops.reshape(ret,
array_ops.concat([
array_ops.shape(ids),
array_ops.slice(params_shape, [1], [-1])
], 0))
# output shape = ids.shape + params[*].shape[1:]
# Normally the reshape is sufficient, but setting shape explicitly
# teaches shape inference that params[1:].get_shape() matters.
ret.set_shape(ids.get_shape().concatenate(element_shape))
return maybe_normalize(ret)
# Stitch these back together
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
name=name)
# Reshape to reverse the flattening of ids.
element_shape = params[0].get_shape()[1:]
for p in params[1:]:
element_shape = element_shape.merge_with(p.get_shape()[1:])
if element_shape.is_fully_defined():
ret = array_ops.reshape(ret,
array_ops.concat(
[array_ops.shape(ids), element_shape], 0))
else:
# We use distributed aggregation.
ignore_weights = weights is None
if not ignore_weights:
# Partition weights according to pindices.
partitioned_weight = []
for p in xrange(np):
partitioned_weight.append(array_ops.gather(weights, pindices[p]))
# Reshape each partition result.
element_shape = params[0].get_shape()[1:]
for p in params[1:]:
element_shape = element_shape.merge_with(p.get_shape()[1:])
if element_shape.is_fully_defined():
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat(
[array_ops.shape(pindices[p]), element_shape], 0))
else:
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat([array_ops.shape(pindices[p]),
array_ops.slice(params_shape, [1], [-1])], 0))
# Normalize each partition result.
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = maybe_normalize(partitioned_result[p])
if not ignore_weights:
# Multiply each partition result with partition weights.
for p in xrange(np):
with ops.colocate_with(params[p]):
if partitioned_weight[p].dtype != partitioned_result[p].dtype:
partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
partitioned_result[p].dtype)
# Reshape partition weights.
ones = array_ops.fill(
array_ops.expand_dims(
array_ops.rank(partitioned_result[p]) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(partitioned_weight[p]), ones], 0)
orig_weights_shape = partitioned_weight[p].get_shape()
partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
bcast_weights_shape)
if partitioned_result[p].get_shape().ndims is not None:
partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
[1 for _ in range(
partitioned_result[p].get_shape().ndims - 1)]))
partitioned_result[p] *= partitioned_weight[p]
partitioned_segment_ids = []
for p in xrange(np):
if not ignore_weights:
# Partition segment_ids according to pindices.
p_segment_ids = array_ops.gather(segment_ids, pindices[p])
# Number the p_segment_ids to meet segment_sum's requirements. Note
# that unique_p_segment_ids contains unique segment ids of this
# partiton and these ids' order is unchanged.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
partitioned_segment_ids.append(unique_p_segment_ids)
# segment_sum this partition's result.
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.segment_sum(
partitioned_result[p], unique_p_segment_idx)
else:
# When ignore weights, we need to get indexs of elements in idx and
# segment_ids.
_, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
all_idx = math_ops.range(array_ops.shape(idx)[0])
_, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
# Gather segment_ids and idx according to indexs.
p_segment_ids = array_ops.gather(segment_ids, include_idx)
p_idx = array_ops.gather(idx, include_idx)
# Number the p_segment_ids, same as ignore_weights case above.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
_, unique_p_idx_idx = array_ops.unique(p_idx)
partitioned_segment_ids.append(unique_p_segment_ids)
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.sparse_segment_sum(
partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
# Concat each partition's segment_ids and result for final segment_sum.
concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
concat_partitioned_result = array_ops.concat(partitioned_result, 0)
return math_ops.unsorted_segment_sum(
concat_partitioned_result, concat_segment_ids,
math_ops.reduce_max(concat_segment_ids) + 1)
# It's important that we compute params[0].shape on the right device
# to avoid data motion.
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
ret = array_ops.reshape(ret,
array_ops.concat([
array_ops.shape(ids),
array_ops.slice(params_shape, [1], [-1])
], 0))
# output shape = ids.shape + params[*].shape[1:]
# Normally the reshape is sufficient, but setting shape explicitly
# teaches shape inference that params[1:].get_shape() matters.
ret.set_shape(ids.get_shape().concatenate(element_shape))
return maybe_normalize(ret)
def embedding_lookup_sparse(params, sp_ids, sp_weights,
partition_strategy="mod",
name=None,
combiner=None,
max_norm=None,
use_aggregation=False):
max_norm=None):
"""Computes embeddings for the given ids and weights.
This op assumes that there is at least one id for each row in the dense tensor
@ -381,9 +244,6 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
squares of the weights.
max_norm: If not None, each embedding is normalized to have l2 norm equal
to max_norm before combining.
use_aggregation: If True, embeddings belonging to same param are aggregated
on that device first. This option is intented to reduce data transmission
and increase concurrency.
Returns:
A dense tensor representing the combined embeddings for the
@ -458,89 +318,56 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
else:
idx = None
if not use_aggregation:
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
if not ignore_weights:
weights = sp_weights.values
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
# Reshape weights to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set the weight shape, since after reshaping to bcast_weights_shape,
# the shape becomes None.
if embeddings.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))
embeddings *= weights
if combiner == "sum":
embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
elif combiner == "mean":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weight_sum = math_ops.segment_sum(weights, segment_ids)
embeddings = math_ops.div(embeddings, weight_sum, name=name)
elif combiner == "sqrtn":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weights_squared = math_ops.pow(weights, 2)
weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
else:
assert False, "Unrecognized combiner"
else:
assert idx is not None
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
name=name)
elif combiner == "mean":
embeddings = math_ops.sparse_segment_mean(embeddings, idx,
segment_ids, name=name)
elif combiner == "sqrtn":
embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
segment_ids, name=name)
else:
assert False, "Unrecognized combiner"
else:
weights = None if ignore_weights else sp_weights.values
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
use_aggregation=True,
weights=weights, idx=idx, segment_ids=segment_ids)
# Set weights to all one if ignore weights.
if ignore_weights:
weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
if not ignore_weights:
weights = sp_weights.values
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
# Reshape weights.
# Reshape weights to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set the weight shape, since after reshaping to bcast_weights_shape,
# the shape becomes None.
if embeddings.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))
if combiner == "mean":
embeddings *= weights
if combiner == "sum":
embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
elif combiner == "mean":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weight_sum = math_ops.segment_sum(weights, segment_ids)
embeddings = math_ops.div(embeddings, weight_sum)
embeddings = math_ops.div(embeddings, weight_sum, name=name)
elif combiner == "sqrtn":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weights_squared = math_ops.pow(weights, 2)
weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
embeddings = math_ops.div(embeddings, weight_sum_sqrt)
elif combiner != "sum":
embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
else:
assert False, "Unrecognized combiner"
else:
assert idx is not None
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
name=name)
elif combiner == "mean":
embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
name=name)
elif combiner == "sqrtn":
embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx,
segment_ids, name=name)
else:
assert False, "Unrecognized combiner"
return embeddings