Update `tf.nn.embedding_lookup` to honor the `partition_strategy` and `max_norm` arguments when `ids` is a `RaggedTensor`. (Previously these two parameters were silently ignored whenever `ids` was ragged.)

PiperOrigin-RevId: 293507108
Change-Id: Id5879023d81286cbe79d829d03a303cfca0e1df2
This commit is contained in:
Edward Loper 2020-02-05 19:46:37 -08:00 committed by TensorFlower Gardener
parent 66779177f6
commit 2d663113d9
3 changed files with 52 additions and 3 deletions

View File

@ -40,6 +40,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -644,6 +645,12 @@ class EmbeddingLookupTest(test.TestCase):
# same results. Therefore, an exact comparison cannot be made. # same results. Therefore, an exact comparison cannot be made.
self.assertAllClose(simple, sharded) self.assertAllClose(simple, sharded)
def testRaggedMaxNorm(self):
  """Checks that max_norm clips embeddings looked up with ragged ids.

  The single embedding row is [2.0] (l2-norm 2); with max_norm=1.0 every
  looked-up value must be rescaled to 1.0, preserving the ragged shape.
  """
  params = constant_op.constant([[2.0]])
  ragged_ids = ragged_factory_ops.constant([[0, 0], [0]], dtype=dtypes.int32)
  result = embedding_ops.embedding_lookup([params], ragged_ids, max_norm=1.0)
  expected = [[[1.0], [1.0]], [[1.0]]]
  self.assertAllEqual(result, expected)
class EmbeddingLookupSparseTest(test.TestCase): class EmbeddingLookupSparseTest(test.TestCase):

View File

@ -312,7 +312,10 @@ def embedding_lookup(
ValueError: If `params` is empty. ValueError: If `params` is empty.
""" """
if isinstance(ids, ragged_tensor.RaggedTensor): if isinstance(ids, ragged_tensor.RaggedTensor):
return embedding_lookup_ragged(params, ids) return embedding_lookup_ragged(params, ids,
partition_strategy=partition_strategy,
max_norm=max_norm,
name=name)
return _embedding_lookup_and_transform( return _embedding_lookup_and_transform(
params=params, params=params,
@ -823,7 +826,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
return final_result return final_result
def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None): def embedding_lookup_ragged(embedding_weights,
ragged_ids,
partition_strategy="mod",
max_norm=None,
name=None):
"""Look up the ragged ids in a list of embedding tensors. """Look up the ragged ids in a list of embedding tensors.
Args: Args:
@ -832,6 +839,9 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None):
ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids
to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be
in the range '[0, embedding_weights.shape[0]]'. in the range '[0, embedding_weights.shape[0]]'.
partition_strategy: A string specifying the partitioning strategy.
max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
than this value.
name: A name for the operation (optional) name: A name for the operation (optional)
Returns: Returns:
@ -853,7 +863,11 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None):
with ops.name_scope(name, "embedding_lookup_ragged") as name: with ops.name_scope(name, "embedding_lookup_ragged") as name:
looked_up_ragged = ragged_functional_ops.map_flat_values( looked_up_ragged = ragged_functional_ops.map_flat_values(
array_ops.gather, embedding_weights, ragged_ids) embedding_lookup,
params=embedding_weights,
ids=ragged_ids,
partition_strategy=partition_strategy,
max_norm=max_norm)
return looked_up_ragged return looked_up_ragged

View File

@ -1590,6 +1590,34 @@ class RaggedEmbeddingTest(test_lib.TestCase):
ValueError, "The values contained by the inputs have type*"): ValueError, "The values contained by the inputs have type*"):
nn.embedding_lookup_ragged(weights, ragged_ids) nn.embedding_lookup_ragged(weights, ragged_ids)
def testMaxNormForEmbeddings(self):
  """Checks that max_norm rescales over-norm rows for ragged lookups.

  Each weight row i is [i, i, i, i] with l2-norm 2*i.  For a given
  max_norm, rows whose norm exceeds it are scaled down to that norm and
  the rest pass through unchanged.
  """
  weights = constant_op.constant(
      [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]],
      dtype=dtypes.float32)
  ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]],
                                           ragged_rank=1)
  # Map each max_norm value to the ragged result it should produce.
  expected_by_norm = {
      1: [[[.5, .5, .5, .5], [.5, .5, .5, .5], [.5, .5, .5, .5]],
          [[0, 0, 0, 0]], [[.5, .5, .5, .5], [.5, .5, .5, .5]]],
      2: [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
          [[0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1]]],
      5: [[[1, 1, 1, 1], [2, 2, 2, 2], [2.5, 2.5, 2.5, 2.5]],
          [[0, 0, 0, 0]], [[1, 1, 1, 1], [2, 2, 2, 2]]],
  }
  for max_norm, expected in expected_by_norm.items():
    actual = nn.embedding_lookup(weights, ragged_ids, max_norm=max_norm)
    self.assertAllClose(
        ragged_factory_ops.constant(expected, dtype=float, ragged_rank=1),
        actual)
if __name__ == "__main__": if __name__ == "__main__":
test_lib.main() test_lib.main()