Fix embedding_lookup() bug where normalization did not work with ids of rank != 1.
PiperOrigin-RevId: 157422220
parent 8cad6b824e
commit 822d64f0c6
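With max_norm set, embedding_lookup() is expected to clip each looked-up embedding vector to an L2 norm of at most max_norm. Before this change the clipping axes were computed as range(1, ndims) of the looked-up tensor, which is only correct when ids has rank 1; the new code clips over range(ids_rank, x_rank) instead. Below is a minimal usage sketch of the fixed behavior (TF 1.x graph mode; the shapes and values are illustrative, not taken from the test):

  import numpy as np
  import tensorflow as tf

  params = tf.constant(2.0 * np.ones((6, 4), dtype=np.float32))  # 6 embeddings of width 4
  ids = tf.constant([[0, 1, 2], [3, 4, 5]])                      # rank-2 ids, shape (2, 3)

  # The lookup has shape (2, 3, 4); each [i, j, :] vector should be clipped to norm <= 1.0.
  embedded = tf.nn.embedding_lookup(params, ids, max_norm=1.0)

  with tf.Session() as sess:
    print(np.linalg.norm(sess.run(embedded), axis=-1))  # expect values of ~1.0 everywhere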
@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
             sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
             self.assertAllEqual(simple, sharded)
 
+  def testHigherRankMaxNorm(self):
+    np.random.seed(8)
+    with self.test_session():
+      for params_shape in (12,), (6, 3):
+        params = 2 * np.ones(params_shape)
+        params_norm = params / np.sqrt(
+            np.sum(params * params, tuple(range(params.ndim)[1:]), keepdims=True))
+        for ids_shape in (), (3), (4, 3), (2, 3, 4):
+          ids = np.random.randint(
+              params.shape[0], size=np.prod(ids_shape, dtype=np.int64)).reshape(
+                  ids_shape)
+          # Compare nonsharded to gather
+          simple = embedding_ops.embedding_lookup(
+              params, ids, max_norm=1.0).eval()
+          self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
+          # Run a few random sharded versions
+          for procs in 1, 2, 3:
+            stride = procs * math_ops.range(params.shape[0] // procs)
+            split_params = [
+                array_ops.gather(params, stride + p) for p in xrange(procs)
+            ]
+            sharded = embedding_ops.embedding_lookup(
+                split_params, ids, max_norm=1.0).eval()
+            self.assertAllEqual(simple, sharded)
+
 
 class EmbeddingLookupSparseTest(test.TestCase):
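The sharded half of the test splits params with a stride so that shard p holds rows p, p + procs, p + 2 * procs, ..., which matches the default "mod" partition strategy of embedding_lookup; the sharded lookup should therefore reproduce the unsharded result exactly. A small NumPy sketch of that layout (values illustrative only):

  import numpy as np

  params = np.arange(12).reshape(12, 1)  # 12 rows, one "embedding" column
  procs = 3
  stride = procs * np.arange(params.shape[0] // procs)       # [0, 3, 6, 9]
  split_params = [params[stride + p] for p in range(procs)]
  # Shard 0 holds rows 0, 3, 6, 9; shard 1 holds rows 1, 4, 7, 10; shard 2 holds
  # rows 2, 5, 8, 11 -- i.e. id i lives on shard i % procs, as "mod" sharding assumes.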
@@ -103,14 +103,25 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
     params = list(params)  # Iterate to get the underlying Variables.
   if not isinstance(params, list):
     params = [params]
 
   def maybe_normalize(x):
-    if max_norm is not None:
-      if x.get_shape().ndims is not None:
-        ndims = x.get_shape().ndims
-      else:
-        ndims = array_ops.size(array_ops.shape(x))
-      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
-    return x
+    """Normalizes the embeddings in x if max_norm is not None."""
+    if max_norm is None:
+      return x
+    static = True
+    ids_rank = ops.convert_to_tensor(ids).get_shape().ndims
+    if ids_rank is None:
+      ids_rank = array_ops.rank(ids)
+      static = False
+    x_rank = x.get_shape().ndims
+    if x_rank is None:
+      x_rank = array_ops.rank(x)
+      static = False
+    return clip_ops.clip_by_norm(
+        x, max_norm,
+        axes=list(range(ids_rank, x_rank)) if static
+        else math_ops.range(ids_rank, x_rank))
 
   with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
     np = len(params)  # Number of partitions
     # Preserve the resource variable status to avoid accidental dense reads.
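In the rewritten maybe_normalize(), the clip is applied over the trailing embedding axes only, i.e. axes ids_rank through x_rank - 1, with a dynamic fallback (array_ops.rank / math_ops.range) when either rank is not statically known. A rough NumPy sketch of the intended normalization follows; the helper clip_by_norm_np is illustrative and is not the TensorFlow clip_by_norm op:

  import numpy as np

  def clip_by_norm_np(x, max_norm, axes):
    # Clip the L2 norm of x over the given axes, mirroring what clip_by_norm does.
    norm = np.sqrt(np.sum(x * x, axis=tuple(axes), keepdims=True))
    return np.where(norm > max_norm, x * (max_norm / norm), x)

  ids = np.array([[0, 1], [2, 3]])             # rank-2 ids
  looked_up = 2.0 * np.ones(ids.shape + (4,))  # shape (2, 2, 4): ids axes + embedding axis
  ids_rank, x_rank = ids.ndim, looked_up.ndim
  clipped = clip_by_norm_np(looked_up, 1.0, range(ids_rank, x_rank))
  # Each clipped[i, j, :] now has L2 norm 1.0.  The old axes, range(1, x_rank),
  # would have folded the second ids axis into the norm whenever ids had rank > 1.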