Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
(Prior to this change, these two parameters were ignored for ragged `ids` values.)
PiperOrigin-RevId: 293507108 Change-Id: Id5879023d81286cbe79d829d03a303cfca0e1df2
This commit is contained in:
parent
66779177f6
commit
2d663113d9
@ -40,6 +40,7 @@ from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import compat
|
||||
@ -644,6 +645,12 @@ class EmbeddingLookupTest(test.TestCase):
|
||||
# same results. Therefore, an exact comparison cannot be made.
|
||||
self.assertAllClose(simple, sharded)
|
||||
|
||||
def testRaggedMaxNorm(self):
|
||||
embeddings = constant_op.constant([[2.0]])
|
||||
ids = ragged_factory_ops.constant([[0, 0], [0]], dtype=dtypes.int32)
|
||||
embedding = embedding_ops.embedding_lookup([embeddings], ids, max_norm=1.0)
|
||||
self.assertAllEqual(embedding, [[[1.0], [1.0]], [[1.0]]])
|
||||
|
||||
|
||||
class EmbeddingLookupSparseTest(test.TestCase):
|
||||
|
||||
|
@ -312,7 +312,10 @@ def embedding_lookup(
|
||||
ValueError: If `params` is empty.
|
||||
"""
|
||||
if isinstance(ids, ragged_tensor.RaggedTensor):
|
||||
return embedding_lookup_ragged(params, ids)
|
||||
return embedding_lookup_ragged(params, ids,
|
||||
partition_strategy=partition_strategy,
|
||||
max_norm=max_norm,
|
||||
name=name)
|
||||
|
||||
return _embedding_lookup_and_transform(
|
||||
params=params,
|
||||
@ -823,7 +826,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
|
||||
return final_result
|
||||
|
||||
|
||||
def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None):
|
||||
def embedding_lookup_ragged(embedding_weights,
|
||||
ragged_ids,
|
||||
partition_strategy="mod",
|
||||
max_norm=None,
|
||||
name=None):
|
||||
"""Look up the ragged ids in a list of embedding tensors.
|
||||
|
||||
Args:
|
||||
@ -832,6 +839,9 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None):
|
||||
ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids
|
||||
to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be
|
||||
in the range '[0, embedding_weights.shape[0]]'.
|
||||
partition_strategy: A string specifying the partitioning strategy.
|
||||
max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
|
||||
than this value.
|
||||
name: A name for the operation (optional)
|
||||
|
||||
Returns:
|
||||
@ -853,7 +863,11 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None):
|
||||
|
||||
with ops.name_scope(name, "embedding_lookup_ragged") as name:
|
||||
looked_up_ragged = ragged_functional_ops.map_flat_values(
|
||||
array_ops.gather, embedding_weights, ragged_ids)
|
||||
embedding_lookup,
|
||||
params=embedding_weights,
|
||||
ids=ragged_ids,
|
||||
partition_strategy=partition_strategy,
|
||||
max_norm=max_norm)
|
||||
|
||||
return looked_up_ragged
|
||||
|
||||
|
@ -1590,6 +1590,34 @@ class RaggedEmbeddingTest(test_lib.TestCase):
|
||||
ValueError, "The values contained by the inputs have type*"):
|
||||
nn.embedding_lookup_ragged(weights, ragged_ids)
|
||||
|
||||
def testMaxNormForEmbeddings(self):
|
||||
weights = constant_op.constant([[0, 0, 0, 0], [1, 1, 1, 1],
|
||||
[2, 2, 2, 2], [3, 3, 3, 3]],
|
||||
dtype=dtypes.float32)
|
||||
ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]],
|
||||
ragged_rank=1)
|
||||
|
||||
actual_embeddings = [
|
||||
nn.embedding_lookup(weights, ragged_ids, max_norm=max_norm)
|
||||
for max_norm in [1, 2, 5]]
|
||||
|
||||
expected_embeddings = (
|
||||
# max_norm = 1
|
||||
[[[.5, .5, .5, .5], [.5, .5, .5, .5], [.5, .5, .5, .5]],
|
||||
[[0, 0, 0, 0]], [[.5, .5, .5, .5], [.5, .5, .5, .5]]],
|
||||
# max_norm = 2
|
||||
[[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1]]],
|
||||
# max_norm = 5
|
||||
[[[1, 1, 1, 1], [2, 2, 2, 2], [2.5, 2.5, 2.5, 2.5]],
|
||||
[[0, 0, 0, 0]], [[1, 1, 1, 1], [2, 2, 2, 2]]],
|
||||
)
|
||||
|
||||
for expected, actual in zip(expected_embeddings, actual_embeddings):
|
||||
self.assertAllClose(
|
||||
ragged_factory_ops.constant(expected, dtype=float, ragged_rank=1),
|
||||
actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
Loading…
Reference in New Issue
Block a user