Merge pull request #21677 from markdryan:markdryan-avx512-fix-embedding_ops_test

PiperOrigin-RevId: 236360578
This commit is contained in:
TensorFlower Gardener 2019-03-01 12:58:28 -08:00
commit 5e29089ffc

View File

@ -584,7 +584,13 @@ class EmbeddingLookupTest(test.TestCase):
# Compare nonsharded to gather # Compare nonsharded to gather
simple = embedding_ops.embedding_lookup( simple = embedding_ops.embedding_lookup(
params, ids, max_norm=1.0).eval() params, ids, max_norm=1.0).eval()
self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval()) # assertAllClose is used here as different implementations of sqrt may
# be used to compute each of the values being compared. For example,
# on AVX512 builds the embedding operation makes use of Eigen's fast
# vectorized square root algorithm for doubles. These different
# implementations of sqrt are not guaranteed to produce exactly the
# same results. Therefore, an exact comparison cannot be made.
self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
# Run a few different sharded versions. # Run a few different sharded versions.
for procs in 1, 2, 3: for procs in 1, 2, 3:
stride = procs * math_ops.range(params.shape[0] // procs) stride = procs * math_ops.range(params.shape[0] // procs)
@ -630,7 +636,13 @@ class EmbeddingLookupTest(test.TestCase):
sharded = embedding_ops._embedding_lookup_and_transform( sharded = embedding_ops._embedding_lookup_and_transform(
split_params, ids, max_norm=l2_norm, split_params, ids, max_norm=l2_norm,
transform_fn=transform).eval() transform_fn=transform).eval()
self.assertAllEqual(simple, sharded) # assertAllClose is used here as different implementations of sqrt may
# be used to compute each of the values being compared. For example,
# on AVX512 builds the embedding operation makes use of Eigen's fast
# vectorized square root algorithm for doubles. These different
# implementations of sqrt are not guaranteed to produce exactly the
# same results. Therefore, an exact comparison cannot be made.
self.assertAllClose(simple, sharded)
class EmbeddingLookupSparseTest(test.TestCase): class EmbeddingLookupSparseTest(test.TestCase):