Merge pull request #8650 from suiyuan2009/add-distributed-aggregation-for-embedding_lookup_sparse
Add distributed aggregation for embedding lookup sparse
commit 9a7481383e
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op

@@ -26,15 +28,18 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique"
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
    "embedding_lookup_sparse_with_distributed_aggregation"
]

@@ -548,3 +553,326 @@ def _sampled_scattered_embedding_lookup_sparse(params,
  return math_ops.unsorted_segment_sum(embeddings, segment_ids,
                                       num_segments=num_segments,
                                       name=name_scope)


def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,
    sp_weights, partition_strategy="mod", name=None, combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings belonging to the same param are aggregated on that param's device
  first. This op is intended to decrease data transmission and improve
  parallelism. See `tf.nn.embedding_lookup_sparse` for the functionality and
  an example of this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of the same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is normalized to have an l2 norm
      equal to max_norm before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

    weights = None if ignore_weights else sp_weights.values
    embeddings = _embedding_lookup_with_distributed_aggregation(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
        weights=weights, idx=idx, segment_ids=segment_ids)
    # Set weights to all ones if weights are ignored.
    if ignore_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)
    # Reshape weights to allow broadcast.
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                           0)
    orig_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, bcast_weights_shape)
    if embeddings.get_shape().ndims is not None:
      weights.set_shape(orig_weights_shape.concatenate(
          [1 for _ in range(embeddings.get_shape().ndims - 1)]))

    if combiner == "mean":
      weight_sum = math_ops.segment_sum(weights, segment_ids)
      embeddings = math_ops.div(embeddings, weight_sum)
    elif combiner == "sqrtn":
      weights_squared = math_ops.pow(weights, 2)
      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
    elif combiner != "sum":
      assert False, "Unrecognized combiner"
    return embeddings


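The public function above mirrors `tf.nn.embedding_lookup_sparse`; only the place where partial sums are computed changes. As a minimal, hedged usage sketch (the import path is assumed from the contrib layout, since this view does not show file names), a two-shard lookup under the default "mod" partition strategy could look like this:

# A minimal sketch (assumed import path; not part of the diff).
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import embedding_ops as contrib_embedding_ops

# Two shards of a [4, 3] embedding table under the default "mod" strategy:
# full-table row i lives at shard[i % 2][i // 2].
shard0 = tf.constant([[0., 1., 2.], [3., 4., 5.]])    # rows 0 and 2
shard1 = tf.constant([[6., 7., 8.], [9., 10., 11.]])  # rows 1 and 3

# Batch of two rows: row 0 uses ids [0, 3] with weights [2.0, 1.0]; row 1 uses id [1].
sp_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([0, 3, 1], tf.int64),
    dense_shape=[2, 2])
sp_weights = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([2.0, 1.0, 1.0]),
    dense_shape=[2, 2])

combined = contrib_embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
    [shard0, shard1], sp_ids, sp_weights, combiner="mean")

with tf.Session() as sess:
  print(sess.run(combined))
  # Row 0: (2 * [0, 1, 2] + 1 * [9, 10, 11]) / 3 = [3., 4., 5.] ("mean" combiner);
  # row 1 is just embedding row 1, i.e. [6., 7., 8.].

The per-shard partial sums run colocated with each shard, so only the already-aggregated rows travel back to the caller.
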
def _do_gather(params, ids, validate_indices=True, name=None):
  """Deals with doing gather differently for resource variables."""
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  return array_ops.gather(
      params, ids, name=name, validate_indices=validate_indices)


def _embedding_lookup_with_distributed_aggregation(params, ids,
    partition_strategy="mod", name=None, validate_indices=True, max_norm=None,
    weights=None, idx=None, segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x
  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(isinstance(p, resource_variable_ops.ResourceVariable)
               for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(
            _do_gather(
                params[0], ids, validate_indices=validate_indices))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape weights to allow broadcast.
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set the weights' static shape after the reshape.
          if ret.get_shape().ndims is not None:
            weights.set_shape(orig_weights_shape.concatenate(
                [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(
            p_assignments < extras, flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (
                flat_ids % (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) * (
                (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists.
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(
              _do_gather(params[p], gather_ids[p],
                         validate_indices=validate_indices))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
                array_ops.concat(
                    [array_ops.shape(pindices[p]), element_shape], 0))
      else:
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(partitioned_result[p],
                array_ops.concat([array_ops.shape(pindices[p]),
                                  array_ops.slice(params_shape, [1], [-1])], 0))
      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result by its partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
                                                    partitioned_result[p].dtype)
            # Reshape partition weights.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
                                                      bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
                  [1 for _ in range(
                      partitioned_result[p].get_shape().ndims - 1)]))
            partitioned_result[p] *= partitioned_weight[p]
      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains the unique segment ids of this
          # partition and that their order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # When weights are ignored, we need the indices of this partition's
          # elements in both idx and segment_ids.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
          # Gather segment_ids and idx according to these indices.
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as in the weighted case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result, concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1, name=name)
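The id-to-partition math described by the comments above reads more easily in plain Python. The helper below is a hypothetical illustration (`assign_ids` is not part of the change) that mirrors the tensor arithmetic, with the indicator-tensor trick replaced by an ordinary conditional:

# Hypothetical sketch of the "mod"/"div" assignment logic, in plain Python.
def assign_ids(flat_ids, num_partitions, num_total_ids, strategy="mod"):
  """Returns (partition, new_id) pairs matching the logic above."""
  ids_per_partition, extras = divmod(num_total_ids, num_partitions)
  out = []
  for i in flat_ids:
    if strategy == "mod":
      out.append((i % num_partitions, i // num_partitions))
    else:  # "div": the first `extras` partitions hold one extra id each.
      p = max(i // (ids_per_partition + 1), (i - extras) // ids_per_partition)
      if p < extras:
        out.append((p, i % (ids_per_partition + 1)))
      else:
        out.append((p, (i - extras) % ids_per_partition))
  return out

# For 13 ids over 5 shards with "div", ids 0..2 land in partition 0, 3..5 in
# partition 1, and so on; e.g.
# assign_ids([0, 3, 12], 5, 13, "div") == [(0, 0), (1, 0), (4, 1)]

The graph code cannot branch per element, which is why it builds `new_ids` from the `is_in_first_extras_partitions` indicator instead of an if/else.
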
@@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class SafeEmbeddingLookupSparseTest(test.TestCase):
@@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
      self.assertAllClose(result.eval(), result_abc.eval())


def _PName(param_id):
  return "p" + str(param_id)


def _EmbeddingParams(num_shards,
                     vocab_size,
                     dtype=dtypes.float32,
                     shape=None,
                     use_shapeless_placeholder=False):
  p = []
  params = {}
  feed_dict = {}
  if not shape:
    shape = [10]
  for i in range(num_shards):
    shard_shape = [vocab_size // num_shards] + shape
    if i < vocab_size % num_shards:  # Excess goes evenly on the first shards
      shard_shape[0] += 1

    param_name = _PName(i)

    if use_shapeless_placeholder:
      param = array_ops.placeholder(dtype, shape=None, name=param_name)
    else:
      param = constant_op.constant(
          1.0, shape=shard_shape, dtype=dtype, name=param_name)
    p.append(param)
    np_type = "f" if dtype == dtypes.float32 else "d"
    val = (np.random.rand(*shard_shape).astype(np_type)) + 1
    params[param_name + ":0"] = val
    feed_dict[param.name] = val
  return p, params, feed_dict


def _EmbeddingResult(params,
                     id_vals,
                     num_shards,
                     vocab_size,
                     partition_strategy="mod",
                     weight_vals=None):
  if weight_vals is None:
    weight_vals = np.copy(id_vals)
    weight_vals.fill(1)
  values = []
  weights = []
  weights_squared = []
  for ids, wts in zip(id_vals, weight_vals):
    value_aggregation = None
    weight_aggregation = None
    squared_weight_aggregation = None
    if isinstance(ids, compat.integral_types):
      ids = [ids]
      wts = [wts]
    for i, weight_value in zip(ids, wts):
      if partition_strategy == "mod":
        val = np.copy(params[_PName(i % num_shards) + ":0"][
            i // num_shards, :]) * weight_value
      elif partition_strategy == "div":
        ids_per_partition, extras = divmod(vocab_size, num_shards)
        threshold = extras * (ids_per_partition + 1)
        if i < threshold:
          partition = i // (ids_per_partition + 1)
          offset = i % (ids_per_partition + 1)
        else:
          partition = extras + (i - threshold) // ids_per_partition
          offset = (i - threshold) % ids_per_partition
        val = np.copy(params[_PName(partition) + ":0"][
            offset, :]) * weight_value
      else:
        assert False
      if value_aggregation is None:
        assert weight_aggregation is None
        assert squared_weight_aggregation is None
        value_aggregation = val
        weight_aggregation = weight_value
        squared_weight_aggregation = weight_value * weight_value
      else:
        assert weight_aggregation is not None
        assert squared_weight_aggregation is not None
        value_aggregation += val
        weight_aggregation += weight_value
        squared_weight_aggregation += weight_value * weight_value
    values.append(value_aggregation)
    weights.append(weight_aggregation)
    weights_squared.append(squared_weight_aggregation)
  values = np.array(values).astype(np.float32)
  weights = np.array(weights).astype(np.float32)
  weights_squared = np.array(weights_squared).astype(np.float32)
  return values, weights, weights_squared


class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):

  def _RandomIdsAndWeights(self, batch_size, vocab_size):
    max_val_per_entry = 6
    vals_per_batch_entry = np.random.randint(
        1, max_val_per_entry, size=batch_size)
    num_vals = np.sum(vals_per_batch_entry)

    ids = np.random.randint(vocab_size, size=num_vals)
    weights = 1 + np.random.rand(num_vals)

    indices = []
    for batch_entry, num_val in enumerate(vals_per_batch_entry):
      for val_index in range(num_val):
        indices.append([batch_entry, val_index])

    shape = [batch_size, max_val_per_entry]

    sp_ids = sparse_tensor_lib.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(ids, dtypes.int32),
        constant_op.constant(shape, dtypes.int64))
    sp_weights = sparse_tensor_lib.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(weights, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

    return sp_ids, sp_weights, ids, weights, vals_per_batch_entry

  def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
    grouped_vals = []
    index = 0
    for num_val in vals_per_batch_entry:
      grouped_vals.append(list(vals[index:(index + num_val)]))
      index += num_val
    return grouped_vals

  def testEmbeddingLookupSparse(self):
    vocab_size = 13
    batch_size = 10
    param_shape = [2, 5]
    expected_lookup_result_shape = [None] + param_shape

    sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
        self._RandomIdsAndWeights(batch_size, vocab_size))

    grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
    grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
    grouped_ignored_weights = self._GroupByBatchEntry(
        np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)

    for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5],
        ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
        [True, False]):

      with self.test_session():
        p, params, feed_dict = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)
        embedding_sum = \
            embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
                p,
                sp_ids,
                None if ignore_weights else sp_weights,
                combiner=combiner)

        self.assertEqual(embedding_sum.get_shape().as_list(),
                         expected_lookup_result_shape)

        tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)

        np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
            params,
            grouped_ids,
            num_shards,
            vocab_size,
            weight_vals=grouped_ignored_weights if ignore_weights else
            grouped_weights)
        if combiner == "mean":
          np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
        if combiner == "sqrtn":
          np_embedding_sum /= np.reshape(
              np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
        self.assertAllClose(np_embedding_sum, tf_embedding_sum)

  def testGradientsEmbeddingLookupSparse(self):
    vocab_size = 12
    batch_size = 4
    param_shape = [2, 3]
    sp_ids, sp_weights, _, _, _ = (
        self._RandomIdsAndWeights(batch_size, vocab_size))

    for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3],
        ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
        [True, False]):
      with self.test_session():
        x, params, _ = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)

        y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
            x,
            sp_ids,
            None if ignore_weights else sp_weights,
            combiner=combiner)
        x_name = [_PName(i) for i in range(num_shards)]
        x_init_value = [params[x_n + ":0"] for x_n in x_name]
        x_shape = [i.shape for i in x_init_value]
        y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
        err = gradient_checker.compute_gradient_error(
            x, x_shape, y, y_shape, x_init_value=x_init_value)
        self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)

  def testIncompatibleShapes(self):
    with self.test_session():
      x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
      sp_ids = sparse_tensor_lib.SparseTensor(
          constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
          constant_op.constant([0, 1, 2], dtypes.int32),
          constant_op.constant([2, 2], dtypes.int64))
      sp_weights = sparse_tensor_lib.SparseTensor(
          constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
          constant_op.constant([12.0, 5.0], dtypes.float32),
          constant_op.constant([1, 2], dtypes.int64))

      with self.assertRaises(ValueError):
        embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
            x, sp_ids, sp_weights, combiner="mean")


if __name__ == "__main__":
  test.main()
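For reference, the combiner semantics exercised by these tests (and described in the docstring earlier in the diff) amount to the standard formulas below, where e_i are the looked-up embeddings and w_i the matching weights of one sp_ids row (a sketch, not text from the commit):

\text{sum} = \sum_i w_i e_i, \qquad
\text{mean} = \frac{\sum_i w_i e_i}{\sum_i w_i}, \qquad
\text{sqrtn} = \frac{\sum_i w_i e_i}{\sqrt{\sum_i w_i^2}}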