diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 2bd21fb01d1..057da9d7afa 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
             sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
             self.assertAllEqual(simple, sharded)
 
+  def testHigherRankMaxNorm(self):
+    np.random.seed(8)
+    with self.test_session():
+      for params_shape in (12,), (6, 3):
+        params = 2 * np.ones(params_shape)
+        params_norm = params / np.sqrt(
+            np.sum(params*params, tuple(range(params.ndim)[1:]), keepdims=True))
+        for ids_shape in (), (3), (4, 3), (2, 3, 4):
+          ids = np.random.randint(
+              params.shape[0], size=np.prod(ids_shape, dtype=np.int64)).reshape(
+                  ids_shape)
+          # Compare nonsharded to gather
+          simple = embedding_ops.embedding_lookup(
+              params, ids, max_norm=1.0).eval()
+          self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
+          # Run a few random sharded versions
+          for procs in 1, 2, 3:
+            stride = procs * math_ops.range(params.shape[0] // procs)
+            split_params = [
+                array_ops.gather(params, stride + p) for p in xrange(procs)
+            ]
+            sharded = embedding_ops.embedding_lookup(
+                split_params, ids, max_norm=1.0).eval()
+            self.assertAllEqual(simple, sharded)
+
 
 class EmbeddingLookupSparseTest(test.TestCase):
 
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 315e7d4b43c..6930f9af05f 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -103,14 +103,25 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
     params = list(params)  # Iterate to get the underlying Variables.
   if not isinstance(params, list):
     params = [params]
+
   def maybe_normalize(x):
-    if max_norm is not None:
-      if x.get_shape().ndims is not None:
-        ndims = x.get_shape().ndims
-      else:
-        ndims = array_ops.size(array_ops.shape(x))
-      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
-    return x
+    """Normalizes the embeddings in x if max_norm is not None."""
+    if max_norm is None:
+      return x
+    static = True
+    ids_rank = ops.convert_to_tensor(ids).get_shape().ndims
+    if ids_rank is None:
+      ids_rank = array_ops.rank(ids)
+      static = False
+    x_rank = x.get_shape().ndims
+    if x_rank is None:
+      x_rank = array_ops.rank(x)
+      static = False
+    return clip_ops.clip_by_norm(
+        x, max_norm,
+        axes=list(range(ids_rank, x_rank)) if static
+        else math_ops.range(ids_rank, x_rank))
+
   with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
     np = len(params)  # Number of partitions
     # Preserve the resource variable status to avoid accidental dense reads.
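
Not part of the patch: a minimal usage sketch of the behavior the new test exercises, assuming the public tf.nn.embedding_lookup API and a TF 1.x-style Session. With max_norm set, each looked-up embedding is clipped to that L2 norm individually, even when ids has rank greater than one.

# Illustrative sketch only; not part of the diff above. Assumes TensorFlow 1.x.
import numpy as np
import tensorflow as tf

params = 2 * np.ones((6, 3))      # every row has L2 norm 2 * sqrt(3)
ids = np.array([[0, 2], [4, 5]])  # rank-2 ids
looked_up = tf.nn.embedding_lookup(params, ids, max_norm=1.0)

with tf.Session() as sess:
  out = sess.run(looked_up)       # shape is ids.shape + (3,) == (2, 2, 3)

# Each gathered embedding is scaled down to unit norm, not the whole slice.
np.testing.assert_allclose(np.linalg.norm(out, axis=-1), np.ones((2, 2)))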