317 lines
13 KiB
Python
317 lines
13 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ===================================================================
|
|
"""Tests for python.tpu.feature_column."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.feature_column import feature_column as fc
|
|
from tensorflow.python.feature_column import feature_column_lib as fc_lib
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import sparse_tensor
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import lookup_ops
|
|
from tensorflow.python.ops import parsing_ops
|
|
from tensorflow.python.ops import variables as variables_lib
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.tpu import feature_column as tpu_fc
|
|
|
|
|
|
def _initialized_session():
  """Returns a new `Session` with variables and lookup tables initialized.

  The global-variables initializer is built and run first, then the tables
  initializer. Tests in this file use the returned session as a context
  manager.
  """
  fresh_session = session.Session()
  # Each initializer op is created immediately before it is run, so creation
  # and execution stay interleaved.
  for build_init_op in (variables_lib.global_variables_initializer,
                        lookup_ops.tables_initializer):
    fresh_session.run(build_init_op())
  return fresh_session
|
|
|
|
|
|
class EmbeddingColumnTest(test.TestCase):
  """Tests for `tpu_fc.embedding_column`."""

  def _assert_column_properties(self, column, wrapped, dimension, combiner):
    """Runs the property assertions shared by the constructor tests.

    Every caller in this class builds its categorical column with key 'aaa',
    which is why the expected names and parse spec are fixed here.

    Args:
      column: Object returned by `tpu_fc.embedding_column`.
      wrapped: The categorical column that was handed to the constructor.
      dimension: Expected embedding dimension.
      combiner: Expected combiner string.
    """
    self.assertIs(wrapped, column.categorical_column)
    self.assertEqual(dimension, column.dimension)
    self.assertEqual(combiner, column.combiner)
    self.assertEqual('aaa_embedding', column.name)
    self.assertEqual('aaa_embedding', column._var_scope_name)
    self.assertEqual((dimension,), column._variable_shape)
    expected_spec = {'aaa': parsing_ops.VarLenFeature(dtypes.int64)}
    self.assertEqual(expected_spec, column._parse_example_spec)

  def test_defaults(self):
    """Checks column properties when only `dimension` is supplied."""
    dim = 2
    identity_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    column = tpu_fc.embedding_column(identity_column, dimension=dim)
    self._assert_column_properties(column, identity_column, dim, 'mean')

  def test_denylisted_column(self):
    """Wrapping a hashed categorical column must raise `TypeError`."""
    # HashedCategoricalColumn is denylisted and so will raise an exception.
    dim = 2
    hashed_column = fc_lib.categorical_column_with_hash_bucket(
        key='aaa', hash_bucket_size=3)
    with self.assertRaises(TypeError):
      tpu_fc.embedding_column(hashed_column, dimension=dim)

  def test_custom_column(self):
    """Checks column properties for an identity column with 10 buckets."""
    # This column is not in any allowlist but should succeed because
    # it inherits from V2 CategoricalColumn.
    dim = 2
    identity_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=10)
    column = tpu_fc.embedding_column(identity_column, dimension=dim)
    self._assert_column_properties(column, identity_column, dim, 'mean')

  def test_all_constructor_args(self):
    """Checks column properties when combiner and initializer are passed."""
    dim = 2
    identity_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    column = tpu_fc.embedding_column(
        identity_column,
        dimension=dim,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer')
    self._assert_column_properties(
        column, identity_column, dim, 'my_combiner')

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Looks up a sparse batch and compares against hand-computed means."""
    num_ids = 3
    dim = 2

    # Sparse batch of four examples:
    #   example 0, ids [2]
    #   example 1, ids [0, 1]
    #   example 2, ids []
    #   example 3, ids [1]
    sparse_batch = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(4, 5))

    # Fixed embedding table, one row per id.
    table = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _checked_initializer(shape, dtype, partition_info):
      """Validates the arguments it receives, then returns the fixed table."""
      self.assertAllEqual((num_ids, dim), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return table

    # Expected output per example, using combiner='mean'.
    expected = (
        (7., 11.),  # example 0, ids [2], embedding = [7, 11]
        (2., 3.5),  # example 1, ids [0, 1], mean([1, 2] + [3, 5]) = [2, 3.5]
        (0., 0.),  # example 2, ids [], embedding = [0, 0]
        (3., 5.),  # example 3, ids [1], embedding = [3, 5]
    )

    # Build the columns under test.
    identity_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=num_ids)
    column = tpu_fc.embedding_column(
        identity_column, dimension=dim, initializer=_checked_initializer)

    # Feed the sparse batch and obtain the dense lookup tensor.
    builder = fc._LazyBuilder({'aaa': sparse_batch})
    lookup = column._get_dense_tensor(builder)

    # Exactly one global variable is expected, holding the table above.
    created_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertItemsEqual(('embedding_weights:0',),
                          tuple(var.name for var in created_vars))
    with _initialized_session():
      self.assertAllEqual(table, created_vars[0])
      self.assertAllEqual(expected, lookup)
|
|
|
|
|
|
class SharedEmbeddingColumnTest(test.TestCase):
  """Tests for `tpu_fc.shared_embedding_columns`."""

  @test_util.deprecated_graph_mode_only
  def test_defaults(self):
    """Checks both shared columns when only `dimension` is supplied."""
    dim = 2
    cat_a = fc_lib.categorical_column_with_identity(key='aaa', num_buckets=3)
    cat_b = fc_lib.categorical_column_with_identity(key='bbb', num_buckets=3)
    # Columns are passed as [b, a] and the results unpacked in that same
    # order; the assertions below still expect 'aaa' before 'bbb' in the
    # shared name.
    col_b, col_a = tpu_fc.shared_embedding_columns(
        [cat_b, cat_a], dimension=dim)
    pair = (col_a, col_b)

    for wrapped, column in zip((cat_a, cat_b), pair):
      self.assertIs(wrapped, column.categorical_column)
    for column in pair:
      self.assertEqual(dim, column.dimension)
    for column in pair:
      self.assertEqual('mean', column.combiner)
    for column in pair:
      self.assertIsNotNone(column.initializer)
    for column in pair:
      self.assertEqual('aaa_bbb_shared_embedding',
                       column.shared_embedding_collection_name)
    for expected_name, column in zip(
        ('aaa_shared_embedding', 'bbb_shared_embedding'), pair):
      self.assertEqual(expected_name, column.name)
    for column in pair:
      self.assertEqual('aaa_bbb_shared_embedding', column._var_scope_name)
    for column in pair:
      self.assertEqual((dim,), column._variable_shape)
    for key, column in zip(('aaa', 'bbb'), pair):
      self.assertEqual({key: parsing_ops.VarLenFeature(dtypes.int64)},
                       column._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_all_constructor_args(self):
    """Checks both shared columns when every constructor argument is set."""
    dim = 2
    cat_a = fc_lib.categorical_column_with_identity(key='aaa', num_buckets=3)
    cat_b = fc_lib.categorical_column_with_identity(key='bbb', num_buckets=3)
    col_a, col_b = tpu_fc.shared_embedding_columns(
        [cat_a, cat_b],
        dimension=dim,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer',
        shared_embedding_collection_name='var_scope_name')
    pair = (col_a, col_b)

    for wrapped, column in zip((cat_a, cat_b), pair):
      self.assertIs(wrapped, column.categorical_column)
    for column in pair:
      self.assertEqual(dim, column.dimension)
    for column in pair:
      self.assertEqual('my_combiner', column.combiner)
    for column in pair:
      self.assertEqual('my_initializer', column.initializer())
    for column in pair:
      self.assertEqual('var_scope_name',
                       column.shared_embedding_collection_name)
    for expected_name, column in zip(
        ('aaa_shared_embedding', 'bbb_shared_embedding'), pair):
      self.assertEqual(expected_name, column.name)
    for column in pair:
      self.assertEqual('var_scope_name', column._var_scope_name)
    for column in pair:
      self.assertEqual((dim,), column._variable_shape)
    for key, column in zip(('aaa', 'bbb'), pair):
      self.assertEqual({key: parsing_ops.VarLenFeature(dtypes.int64)},
                       column._parse_example_spec)

  @test_util.deprecated_graph_mode_only
  def test_get_dense_tensor(self):
    """Looks up two dense id batches through one shared embedding table."""
    num_ids = 3
    dim = 2

    # Dense id batches; -1 values are ignored.
    ids_a = np.array([
        [2, -1, -1],  # example 0, ids [2]
        [0, 1, -1],  # example 1, ids [0, 1]
    ])
    ids_b = np.array([
        [0, -1, -1],  # example 0, ids [0]
        [-1, -1, -1],  # example 1, ids []
    ])
    features = {'aaa': ids_a, 'bbb': ids_b}

    # Fixed embedding table, one row per id.
    table = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _checked_initializer(shape, dtype, partition_info):
      """Validates the arguments it receives, then returns the fixed table."""
      self.assertAllEqual((num_ids, dim), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return table

    # Expected outputs per example, using combiner='mean'.
    expected_a = (
        (7., 11.),  # example 0: ids [2], embedding = [7, 11]
        (2., 3.5),  # example 1: ids [0, 1], mean([1, 2] + [3, 5]) = [2, 3.5]
    )
    expected_b = (
        (1., 2.),  # example 0: ids [0], embedding = [1, 2]
        (0., 0.),  # example 1: ids [], embedding = [0, 0]
    )

    # Build the columns under test.
    cat_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=num_ids)
    cat_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=num_ids)
    col_a, col_b = tpu_fc.shared_embedding_columns(
        [cat_a, cat_b], dimension=dim, initializer=_checked_initializer)

    # Each column gets its own builder over the same feature dict.
    lookup_a = col_a._get_dense_tensor(fc._LazyBuilder(features))
    lookup_b = col_b._get_dense_tensor(fc._LazyBuilder(features))

    # Exactly one global variable is expected, shared by both columns.
    created_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertItemsEqual(('embedding_weights:0',),
                          tuple(var.name for var in created_vars))
    shared_table_var = created_vars[0]
    with _initialized_session():
      self.assertAllEqual(table, shared_table_var)
      self.assertAllEqual(expected_lookups_a := expected_a, lookup_a) if False else self.assertAllEqual(expected_a, lookup_a)
      self.assertAllEqual(expected_b, lookup_b)
|
|
|
|
|
|
# Script entry point: hand control to the TensorFlow test runner imported as
# `test` when this module is executed directly.
if __name__ == '__main__':
  test.main()
|