Allow arbitrary v2 CategoricalColumns to be used with TPUEmbeddingColumn

V2 FeatureColumns allow for users to extend the base columns with their own custom columns. The current whitelist does not allow these custom columns to be used.

PiperOrigin-RevId: 275311069
Change-Id: I1ec15a04fbb588826debcabbaf6ceaf74f3a20ca
This commit is contained in:
Amy Skerry-Ryan 2019-10-17 12:15:28 -07:00 committed by TensorFlower Gardener
parent d32472edeb
commit 30741b3911
2 changed files with 43 additions and 5 deletions

View File

@ -34,11 +34,15 @@ _TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn,
fc_lib.SequenceCategoricalColumn)
_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.IdentityCategoricalColumn,
fc_lib.VocabularyFileCategoricalColumn,
fc_lib.VocabularyListCategoricalColumn,
fc_lib.WeightedCategoricalColumn,
fc_lib.SequenceCategoricalColumn)
# For V2 columns, we support anything that inherits from CategoricalColumn
# other than those in the blacklist. User-provided columns that inherit from
# CategoricalColumn may or may not be compatible; it is up to the user to
# manage TPU compatibility for custom columns.
_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,)
_BLACKLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn,
fc_lib.BucketizedColumn,
fc_lib.CrossedColumn)
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
fc._VocabularyFileCategoricalColumn,
fc._VocabularyListCategoricalColumn,
@ -89,7 +93,12 @@ def embedding_column(categorical_column,
Raises:
ValueError: if `dimension` not > 0.
ValueError: if `initializer` is specified but not callable.
TypeError: if categorical_column is not a supported type.
"""
if isinstance(categorical_column, _BLACKLISTED_CATEGORICAL_COLUMNS_V2):
raise TypeError('categorical_column for tpu '
' embedding_column was blacklisted type %s' %
type(categorical_column))
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '
@ -191,6 +200,10 @@ def shared_embedding_columns(categorical_columns,
or 0 for a sequence column.
"""
for categorical_column in categorical_columns:
if isinstance(categorical_column, _BLACKLISTED_CATEGORICAL_COLUMNS_V2):
raise TypeError('categorical_column for tpu '
' embedding_column was blacklisted type %s' %
type(categorical_column))
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
raise TypeError(
'categorical_column for tpu '

View File

@ -59,6 +59,31 @@ class EmbeddingColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, embedding_column._parse_example_spec)
def test_blacklisted_column(self):
# HashedCategoricalColumn is blacklisted and so will raise an exception.
categorical_column = fc_lib.categorical_column_with_hash_bucket(
key='aaa', hash_bucket_size=3)
embedding_dimension = 2
with self.assertRaises(TypeError):
tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)
def test_custom_column(self):
# This column is not in any whitelist but should succeed because
# it inherits from V2 CategoricalColumn.
categorical_column = fc_lib.categorical_column_with_identity(
key='aaa', num_buckets=10)
embedding_dimension = 2
embedding_column = tpu_fc.embedding_column(
categorical_column, dimension=embedding_dimension)
self.assertIs(categorical_column, embedding_column.categorical_column)
self.assertEqual(embedding_dimension, embedding_column.dimension)
self.assertEqual('mean', embedding_column.combiner)
self.assertEqual('aaa_embedding', embedding_column.name)
self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
self.assertEqual((embedding_dimension,), embedding_column._variable_shape)
self.assertEqual({'aaa': parsing_ops.VarLenFeature(dtypes.int64)},
embedding_column._parse_example_spec)
def test_all_constructor_args(self):
categorical_column = fc_lib.categorical_column_with_identity(
key='aaa', num_buckets=3)