diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index fec6c310341..b9d9d125d7c 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -40,6 +40,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.util import compat @@ -644,6 +645,12 @@ class EmbeddingLookupTest(test.TestCase): # same results. Therefore, an exact comparison cannot be made. self.assertAllClose(simple, sharded) + def testRaggedMaxNorm(self): + embeddings = constant_op.constant([[2.0]]) + ids = ragged_factory_ops.constant([[0, 0], [0]], dtype=dtypes.int32) + embedding = embedding_ops.embedding_lookup([embeddings], ids, max_norm=1.0) + self.assertAllEqual(embedding, [[[1.0], [1.0]], [[1.0]]]) + class EmbeddingLookupSparseTest(test.TestCase): diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 69a19e77760..7af8b2dc32a 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -312,7 +312,10 @@ def embedding_lookup( ValueError: If `params` is empty. """ if isinstance(ids, ragged_tensor.RaggedTensor): - return embedding_lookup_ragged(params, ids) + return embedding_lookup_ragged(params, ids, + partition_strategy=partition_strategy, + max_norm=max_norm, + name=name) return _embedding_lookup_and_transform( params=params, @@ -823,7 +826,11 @@ def safe_embedding_lookup_sparse(embedding_weights, return final_result -def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None): +def embedding_lookup_ragged(embedding_weights, + ragged_ids, + partition_strategy="mod", + max_norm=None, + name=None): """Look up the ragged ids in a list of embedding tensors. Args: @@ -832,6 +839,9 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None): ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be in the range '[0, embedding_weights.shape[0]]'. + partition_strategy: A string specifying the partitioning strategy. + max_norm: If not `None`, each embedding is clipped if its l2-norm is larger + than this value. name: A name for the operation (optional) Returns: @@ -853,7 +863,11 @@ def embedding_lookup_ragged(embedding_weights, ragged_ids, name=None): with ops.name_scope(name, "embedding_lookup_ragged") as name: looked_up_ragged = ragged_functional_ops.map_flat_values( - array_ops.gather, embedding_weights, ragged_ids) + embedding_lookup, + params=embedding_weights, + ids=ragged_ids, + partition_strategy=partition_strategy, + max_norm=max_norm) return looked_up_ragged diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 53efadc936f..30df5cf7c47 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1590,6 +1590,34 @@ class RaggedEmbeddingTest(test_lib.TestCase): ValueError, "The values contained by the inputs have type*"): nn.embedding_lookup_ragged(weights, ragged_ids) + def testMaxNormForEmbeddings(self): + weights = constant_op.constant([[0, 0, 0, 0], [1, 1, 1, 1], + [2, 2, 2, 2], [3, 3, 3, 3]], + dtype=dtypes.float32) + ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]], + ragged_rank=1) + + actual_embeddings = [ + nn.embedding_lookup(weights, ragged_ids, max_norm=max_norm) + for max_norm in [1, 2, 5]] + + expected_embeddings = ( + # max_norm = 1 + [[[.5, .5, .5, .5], [.5, .5, .5, .5], [.5, .5, .5, .5]], + [[0, 0, 0, 0]], [[.5, .5, .5, .5], [.5, .5, .5, .5]]], + # max_norm = 2 + [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], + [[0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1]]], + # max_norm = 5 + [[[1, 1, 1, 1], [2, 2, 2, 2], [2.5, 2.5, 2.5, 2.5]], + [[0, 0, 0, 0]], [[1, 1, 1, 1], [2, 2, 2, 2]]], + ) + + for expected, actual in zip(expected_embeddings, actual_embeddings): + self.assertAllClose( + ragged_factory_ops.constant(expected, dtype=float, ragged_rank=1), + actual) + if __name__ == "__main__": test_lib.main()