fix unit test bug
This commit is contained in:
parent
fa751ecd3a
commit
f00b5c9b2e
@ -611,7 +611,8 @@ class EmbeddingLookupSparseTest(test.TestCase):
|
||||
p,
|
||||
sp_ids,
|
||||
None if ignore_weights else sp_weights,
|
||||
combiner=combiner)
|
||||
combiner=combiner,
|
||||
use_aggregation=use_aggregation)
|
||||
|
||||
self.assertEqual(embedding_sum.get_shape().as_list(),
|
||||
expected_lookup_result_shape)
|
||||
@ -650,7 +651,8 @@ class EmbeddingLookupSparseTest(test.TestCase):
|
||||
x,
|
||||
sp_ids,
|
||||
None if ignore_weights else sp_weights,
|
||||
combiner=combiner)
|
||||
combiner=combiner,
|
||||
use_aggregation=use_aggregation)
|
||||
x_name = [_PName(i) for i in range(num_shards)]
|
||||
x_init_value = [params[x_n + ":0"] for x_n in x_name]
|
||||
x_shape = [i.shape for i in x_init_value]
|
||||
|
Loading…
Reference in New Issue
Block a user