fix unit test bug

This commit is contained in:
Ziming Dong 2017-03-23 03:43:46 -04:00
parent fa751ecd3a
commit f00b5c9b2e

View File

@ -611,7 +611,8 @@ class EmbeddingLookupSparseTest(test.TestCase):
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
combiner=combiner,
use_aggregation=use_aggregation)
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
@ -650,7 +651,8 @@ class EmbeddingLookupSparseTest(test.TestCase):
x,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
combiner=combiner,
use_aggregation=use_aggregation)
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]