Fixit for sequence feature column test.

PiperOrigin-RevId: 324697411
Change-Id: Idd0568a80b4b3f82c5c676920a592154efdcc604
This commit is contained in:
Zhenyu Tan 2020-08-03 15:34:22 -07:00 committed by TensorFlower Gardener
parent f18d09553b
commit a0f7e214ae

View File

@ -516,7 +516,6 @@ class SequenceEmbeddingColumnTest(
class SequenceSharedEmbeddingColumnTest(test.TestCase):
@test_util.run_deprecated_v1
def test_get_sequence_dense_tensor(self):
vocabulary_size = 3
embedding_dimension = 2
@ -532,67 +531,68 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
sparse_input_a = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
# example 2, ids []
# example 3, ids [1]
indices=((0, 0), (1, 0), (1, 1), (3, 0)),
values=(2, 0, 1, 1),
dense_shape=(4, 2))
sparse_input_b = sparse_tensor.SparseTensorValue(
# example 0, ids [1]
# example 1, ids [0, 2]
# example 2, ids [0]
# example 3, ids []
indices=((0, 0), (1, 0), (1, 1), (2, 0)),
values=(1, 0, 2, 0),
dense_shape=(4, 2))
with ops.Graph().as_default():
sparse_input_a = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
# example 1, ids [0, 1]
# example 2, ids []
# example 3, ids [1]
indices=((0, 0), (1, 0), (1, 1), (3, 0)),
values=(2, 0, 1, 1),
dense_shape=(4, 2))
sparse_input_b = sparse_tensor.SparseTensorValue(
# example 0, ids [1]
# example 1, ids [0, 2]
# example 2, ids [0]
# example 3, ids []
indices=((0, 0), (1, 0), (1, 1), (2, 0)),
values=(1, 0, 2, 0),
dense_shape=(4, 2))
expected_lookups_a = [
# example 0, ids [2]
[[7., 11.], [0., 0.]],
# example 1, ids [0, 1]
[[1., 2.], [3., 5.]],
# example 2, ids []
[[0., 0.], [0., 0.]],
# example 3, ids [1]
[[3., 5.], [0., 0.]],
]
expected_lookups_a = [
# example 0, ids [2]
[[7., 11.], [0., 0.]],
# example 1, ids [0, 1]
[[1., 2.], [3., 5.]],
# example 2, ids []
[[0., 0.], [0., 0.]],
# example 3, ids [1]
[[3., 5.], [0., 0.]],
]
expected_lookups_b = [
# example 0, ids [1]
[[3., 5.], [0., 0.]],
# example 1, ids [0, 2]
[[1., 2.], [7., 11.]],
# example 2, ids [0]
[[1., 2.], [0., 0.]],
# example 3, ids []
[[0., 0.], [0., 0.]],
]
expected_lookups_b = [
# example 0, ids [1]
[[3., 5.], [0., 0.]],
# example 1, ids [0, 2]
[[1., 2.], [7., 11.]],
# example 2, ids [0]
[[1., 2.], [0., 0.]],
# example 3, ids []
[[0., 0.], [0., 0.]],
]
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
embedding_lookup_a = _get_sequence_dense_tensor(
shared_embedding_columns[0], {'aaa': sparse_input_a})[0]
embedding_lookup_b = _get_sequence_dense_tensor(
shared_embedding_columns[1], {'bbb': sparse_input_b})[0]
embedding_lookup_a = _get_sequence_dense_tensor(
shared_embedding_columns[0], {'aaa': sparse_input_a})[0]
embedding_lookup_b = _get_sequence_dense_tensor(
shared_embedding_columns[1], {'bbb': sparse_input_b})[0]
self.evaluate(variables_lib.global_variables_initializer())
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(('aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
self.assertAllEqual(
expected_lookups_a, self.evaluate(embedding_lookup_a))
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
self.evaluate(variables_lib.global_variables_initializer())
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(('aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
self.assertAllEqual(
expected_lookups_a, self.evaluate(embedding_lookup_a))
self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))
def test_sequence_length(self):
with ops.Graph().as_default():