Merge pull request #21677 from markdryan:markdryan-avx512-fix-embedding_ops_test
PiperOrigin-RevId: 236360578
This commit is contained in:
commit
5e29089ffc
@ -584,7 +584,13 @@ class EmbeddingLookupTest(test.TestCase):
|
||||
# Compare nonsharded to gather
|
||||
simple = embedding_ops.embedding_lookup(
|
||||
params, ids, max_norm=1.0).eval()
|
||||
self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
|
||||
# assertAllClose is used here as different implementations of sqrt may
|
||||
# be used to compute each of the values being compared. For example,
|
||||
# on AVX512 builds the embedding operation makes use of Eigen's fast
|
||||
# vectorized square root algorithm for doubles. These different
|
||||
# implementations of sqrt are not guaranteed to produce exactly the
|
||||
# same results. Therefore, an exact comparison cannot be made.
|
||||
self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
|
||||
# Run a few different sharded versions.
|
||||
for procs in 1, 2, 3:
|
||||
stride = procs * math_ops.range(params.shape[0] // procs)
|
||||
@ -630,7 +636,13 @@ class EmbeddingLookupTest(test.TestCase):
|
||||
sharded = embedding_ops._embedding_lookup_and_transform(
|
||||
split_params, ids, max_norm=l2_norm,
|
||||
transform_fn=transform).eval()
|
||||
self.assertAllEqual(simple, sharded)
|
||||
# assertAllClose is used here as different implementations of sqrt may
|
||||
# be used to compute each of the values being compared. For example,
|
||||
# on AVX512 builds the embedding operation makes use of Eigen's fast
|
||||
# vectorized square root algorithm for doubles. These different
|
||||
# implementations of sqrt are not guaranteed to produce exactly the
|
||||
# same results. Therefore, an exact comparison cannot be made.
|
||||
self.assertAllClose(simple, sharded)
|
||||
|
||||
|
||||
class EmbeddingLookupSparseTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user