Merge pull request #21677 from markdryan:markdryan-avx512-fix-embedding_ops_test
PiperOrigin-RevId: 236360578
This commit is contained in:
commit
5e29089ffc
@ -584,7 +584,13 @@ class EmbeddingLookupTest(test.TestCase):
|
|||||||
# Compare nonsharded to gather
|
# Compare nonsharded to gather
|
||||||
simple = embedding_ops.embedding_lookup(
|
simple = embedding_ops.embedding_lookup(
|
||||||
params, ids, max_norm=1.0).eval()
|
params, ids, max_norm=1.0).eval()
|
||||||
self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
|
# assertAllClose is used here as different implementations of sqrt may
|
||||||
|
# be used to compute each of the values being compared. For example,
|
||||||
|
# on AVX512 builds the embedding operation makes use of Eigen's fast
|
||||||
|
# vectorized square root algorithm for doubles. These different
|
||||||
|
# implementations of sqrt are not guaranteed to produce exactly the
|
||||||
|
# same results. Therefore, an exact comparison cannot be made.
|
||||||
|
self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
|
||||||
# Run a few different sharded versions.
|
# Run a few different sharded versions.
|
||||||
for procs in 1, 2, 3:
|
for procs in 1, 2, 3:
|
||||||
stride = procs * math_ops.range(params.shape[0] // procs)
|
stride = procs * math_ops.range(params.shape[0] // procs)
|
||||||
@ -630,7 +636,13 @@ class EmbeddingLookupTest(test.TestCase):
|
|||||||
sharded = embedding_ops._embedding_lookup_and_transform(
|
sharded = embedding_ops._embedding_lookup_and_transform(
|
||||||
split_params, ids, max_norm=l2_norm,
|
split_params, ids, max_norm=l2_norm,
|
||||||
transform_fn=transform).eval()
|
transform_fn=transform).eval()
|
||||||
self.assertAllEqual(simple, sharded)
|
# assertAllClose is used here as different implementations of sqrt may
|
||||||
|
# be used to compute each of the values being compared. For example,
|
||||||
|
# on AVX512 builds the embedding operation makes use of Eigen's fast
|
||||||
|
# vectorized square root algorithm for doubles. These different
|
||||||
|
# implementations of sqrt are not guaranteed to produce exactly the
|
||||||
|
# same results. Therefore, an exact comparison cannot be made.
|
||||||
|
self.assertAllClose(simple, sharded)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingLookupSparseTest(test.TestCase):
|
class EmbeddingLookupSparseTest(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user