diff --git a/tensorflow/python/tpu/feature_column.py b/tensorflow/python/tpu/feature_column.py index 8a6e71b4baa..cf4a9095567 100644 --- a/tensorflow/python/tpu/feature_column.py +++ b/tensorflow/python/tpu/feature_column.py @@ -34,11 +34,15 @@ _TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' _SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn, fc_lib.SequenceCategoricalColumn) -_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.IdentityCategoricalColumn, - fc_lib.VocabularyFileCategoricalColumn, - fc_lib.VocabularyListCategoricalColumn, - fc_lib.WeightedCategoricalColumn, - fc_lib.SequenceCategoricalColumn) + +# For V2 columns, we support anything that inherits from CategoricalColumn +# other than those in the blacklist. User-provided columns that inherit from +# CategoricalColumn may or may not be compatible; it is up to the user to +# manage TPU compatibility for custom columns. +_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,) +_BLACKLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn, + fc_lib.BucketizedColumn, + fc_lib.CrossedColumn) _SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, fc._VocabularyFileCategoricalColumn, fc._VocabularyListCategoricalColumn, @@ -89,7 +93,12 @@ def embedding_column(categorical_column, Raises: ValueError: if `dimension` not > 0. ValueError: if `initializer` is specified but not callable. + TypeError: if categorical_column is not a supported type. """ + if isinstance(categorical_column, _BLACKLISTED_CATEGORICAL_COLUMNS_V2): + raise TypeError('categorical_column for tpu ' + ' embedding_column was blacklisted type %s' % + type(categorical_column)) if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): raise TypeError( 'categorical_column for tpu ' @@ -191,6 +200,10 @@ def shared_embedding_columns(categorical_columns, or 0 for a sequence column. """ for categorical_column in categorical_columns: + if isinstance(categorical_column, _BLACKLISTED_CATEGORICAL_COLUMNS_V2): + raise TypeError('categorical_column for tpu ' + ' embedding_column was blacklisted type %s' % + type(categorical_column)) if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): raise TypeError( 'categorical_column for tpu ' diff --git a/tensorflow/python/tpu/feature_column_test.py b/tensorflow/python/tpu/feature_column_test.py index 99e66de2ba7..9503fb27fb9 100644 --- a/tensorflow/python/tpu/feature_column_test.py +++ b/tensorflow/python/tpu/feature_column_test.py @@ -59,6 +59,31 @@ class EmbeddingColumnTest(test.TestCase): 'aaa': parsing_ops.VarLenFeature(dtypes.int64) }, embedding_column._parse_example_spec) + def test_blacklisted_column(self): + # HashedCategoricalColumn is blacklisted and so will raise an exception. + categorical_column = fc_lib.categorical_column_with_hash_bucket( + key='aaa', hash_bucket_size=3) + embedding_dimension = 2 + with self.assertRaises(TypeError): + tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension) + + def test_custom_column(self): + # This column is not in any whitelist but should succeed because + # it inherits from V2 CategoricalColumn. + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=10) + embedding_dimension = 2 + embedding_column = tpu_fc.embedding_column( + categorical_column, dimension=embedding_dimension) + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('mean', embedding_column.combiner) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column._variable_shape) + self.assertEqual({'aaa': parsing_ops.VarLenFeature(dtypes.int64)}, + embedding_column._parse_example_spec) + def test_all_constructor_args(self): categorical_column = fc_lib.categorical_column_with_identity( key='aaa', num_buckets=3)