diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 33bed3abcf1..ffdf8868e21 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -121,6 +121,8 @@ from __future__ import print_function import abc import collections +import numpy as np + from tensorflow.python.feature_column import lookup_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -433,6 +435,12 @@ def bucketized_column(source_column, boundaries): return _BucketizedColumn(source_column, tuple(boundaries)) +def _assert_string_or_int(dtype, prefix): + if (dtype != dtypes.string) and (not dtype.is_integer): + raise ValueError( + '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) + + def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -475,9 +483,7 @@ def categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError('dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format(dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) return _HashedCategoricalColumn(key, hash_bucket_size, dtype) @@ -485,7 +491,7 @@ def categorical_column_with_hash_bucket(key, def categorical_column_with_vocabulary_file( key, vocabulary_file, vocabulary_size, num_oov_buckets=0, default_value=None, dtype=dtypes.string): - """Creates a `_CategoricalColumn` with vocabulary file configuration. + """A `_CategoricalColumn` with a vocabulary file. Use this when your inputs are in string or integer format, and you have a vocabulary file that maps each value to an integer ID. By default, @@ -504,7 +510,7 @@ def categorical_column_with_vocabulary_file( ID 50-54. ```python states = categorical_column_with_vocabulary_file( - key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=50, + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, num_oov_buckets=5) linear_prediction = make_linear_model(features, [states, ...]) ``` @@ -516,7 +522,7 @@ def categorical_column_with_vocabulary_file( others are assigned the corresponding line number 1-50. ```python states = categorical_column_with_vocabulary_file( - key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=51, + key='states', vocabulary_file='/us/states.txt', vocabulary_size=51, default_value=0) linear_prediction, _, _ = make_linear_model(features, [states, ...])
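A minimal sketch of how the two out-of-vocabulary strategies above partition the ID space, assuming the same hypothetical '/us/states.txt' vocabulary file used in the docstring examples:

```python
# OOV strategy 1: num_oov_buckets. File lines map to IDs [0, 50); any value
# not in the file is hashed into IDs [50, 55). Downstream layers should size
# their weights by vocabulary_size + num_oov_buckets = 55 buckets total.
states = categorical_column_with_vocabulary_file(
    key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
    num_oov_buckets=5)

# OOV strategy 2: default_value. Every out-of-vocabulary input collapses to
# a single fixed ID (here 0), so the bucket count stays vocabulary_size = 51.
# The two strategies are mutually exclusive: passing both a positive
# num_oov_buckets and a default_value raises ValueError (see the validation
# code later in this patch).
states = categorical_column_with_vocabulary_file(
    key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
    default_value=0)
```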
@@ -530,7 +536,9 @@ def categorical_column_with_vocabulary_file( column name and the dictionary key for feature parsing configs, feature `Tensor` objects, and feature columns. vocabulary_file: The vocabulary file name. - vocabulary_size: Number of the elements in the vocabulary. + vocabulary_size: Number of elements in the vocabulary. This must be no + greater than the length of `vocabulary_file`; if it is less, later values + are ignored. num_oov_buckets: Non-negative integer, the number of out-of-vocabulary buckets. All out-of-vocabulary inputs will be assigned IDs in the range `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of the input value. A positive `num_oov_buckets` can not be specified with `default_value`. default_value: The integer ID value to return for out-of-vocabulary feature values, defaults to `-1`. This can not be specified with a positive `num_oov_buckets`. dtype: The type of features. Only string and integer types are supported. Returns: - A `_CategoricalColumn` with vocabulary file configuration. + A `_CategoricalColumn` with a vocabulary file. Raises: ValueError: `vocabulary_file` is missing. @@ -564,9 +572,8 @@ def categorical_column_with_vocabulary_file( if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError('Invalid dtype {} in {}.'.format(dtype, key)) - return _VocabularyCategoricalColumn( + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + return _VocabularyFileCategoricalColumn( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -575,6 +582,80 @@ def categorical_column_with_vocabulary_file( dtype=dtype) +def categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1): + """A `_CategoricalColumn` with in-memory vocabulary. + + Logic for feature f is: + id = f in vocabulary_list ? vocabulary_list.index(f) : default_value + + Use this when your inputs are in string or integer format, and you have an + in-memory vocabulary mapping each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use `default_value` to specify how to + include out-of-vocabulary values. + + Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values + can be represented by `-1` for int and `''` for string. Note that these values + are independent of the `default_value` argument. + + In the following examples, each input in `vocabulary_list` is assigned an ID + 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other + inputs are assigned `default_value` 0. + + Linear model: + ```python + colors = categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) + linear_prediction, _, _ = make_linear_model(features, [colors, ...]) + ``` + + Embedding for a DNN model: + ```python + dense_tensor = make_input_layer(features, [embedding_column(colors, 3),...]) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The value to use for values not in `vocabulary_list`. + + Returns: + A `_CategoricalColumn` with in-memory vocabulary. + + Raises: + ValueError: if `vocabulary_list` is empty or contains duplicate keys. + ValueError: if `dtype` is not integer or string. + """ + if (vocabulary_list is None) or (len(vocabulary_list) < 1): + raise ValueError( + 'vocabulary_list {} must be non-empty, column_name: {}'.format( + vocabulary_list, key)) + if len(set(vocabulary_list)) != len(vocabulary_list): + raise ValueError( + 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( + vocabulary_list, key)) + vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) + _assert_string_or_int( + vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) + if dtype is None: + dtype = vocabulary_dtype + elif dtype.is_integer != vocabulary_dtype.is_integer: + raise ValueError( + 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( + dtype, vocabulary_dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + + return _VocabularyListCategoricalColumn( + key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, + default_value=default_value) + +
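The docstring's lookup rule is compact enough to restate in plain Python. A small sketch of the per-value semantics (the helper name `_lookup_id` is hypothetical, not part of this patch):

```python
def _lookup_id(value, vocabulary_list, default_value=-1):
  # Mirrors: id = value in vocabulary_list ? vocabulary_list.index(value)
  #                                        : default_value
  if value in vocabulary_list:
    return vocabulary_list.index(value)
  return default_value

vocab = ('X', 'R', 'G', 'B', 'Y')
assert _lookup_id('B', vocab, default_value=0) == 3  # in vocabulary: index 3
assert _lookup_id('Z', vocab, default_value=0) == 0  # OOV: default_value
```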
+ """ + if (vocabulary_list is None) or (len(vocabulary_list) < 1): + raise ValueError( + 'vocabulary_list {} must be non-empty, column_name: {}'.format( + vocabulary_list, key)) + if len(set(vocabulary_list)) != len(vocabulary_list): + raise ValueError( + 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( + vocabulary_list, key)) + vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) + _assert_string_or_int( + vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) + if dtype is None: + dtype = vocabulary_dtype + elif dtype.is_integer != vocabulary_dtype.is_integer: + raise ValueError( + 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( + dtype, vocabulary_dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + + return _VocabularyListCategoricalColumn( + key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, + default_value=default_value) + + class _FeatureColumn(object): """Represents a feature column abstraction. @@ -1170,11 +1251,9 @@ class _HashedCategoricalColumn( if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') - if (input_tensor.dtype != dtypes.string and - not input_tensor.dtype.is_integer): - raise ValueError('input tensors dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format( - input_tensor.dtype, self.key)) + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -1202,8 +1281,9 @@ class _HashedCategoricalColumn( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) -class _VocabularyCategoricalColumn( - _CategoricalColumn, collections.namedtuple('_VocabularyCategoricalColumn', ( +class _VocabularyFileCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyFileCategoricalColumn', ( 'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype', 'default_value' ))): @@ -1226,15 +1306,15 @@ class _VocabularyCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + key_dtype = self.dtype if input_tensor.dtype.is_integer: # `index_table_from_file` requires 64-bit integer keys. key_dtype = dtypes.int64 input_tensor = math_ops.to_int64(input_tensor) - elif input_tensor.dtype != dtypes.string: - raise ValueError('input tensors dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format( - input_tensor.dtype, self.key)) return lookup_ops.index_table_from_file( vocabulary_file=self.vocabulary_file, @@ -1254,6 +1334,56 @@ class _VocabularyCategoricalColumn( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) +class _VocabularyListCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyListCategoricalColumn', ( + 'key', 'vocabulary_list', 'dtype', 'default_value' + ))): + """See `categorical_column_with_vocabulary_list`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. 
' 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_tensor` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + return lookup_ops.index_table_from_tensor( + mapping=tuple(self.vocabulary_list), + default_value=self.default_value, + dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return len(self.vocabulary_list) + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + # TODO(zakaria): Move this to embedding_ops and make it public. def _safe_embedding_lookup_sparse(embedding_weights, sparse_ids,
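Before the tests, a short sketch of the table lookup that `_VocabularyListCategoricalColumn._transform_feature` delegates to, assuming a TF 1.x runtime where an equivalent constructor is exposed as `tf.contrib.lookup.index_table_from_tensor` (the values mirror the warriors-jersey test fixture below):

```python
import tensorflow as tf

# Integer keys are cast to int64 first, because index_table_from_tensor
# requires 64-bit integer keys (hence the math_ops.to_int64 call above).
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=(30, 35, 11, 23, 22), default_value=-1, dtype=tf.int64)
ids = table.lookup(tf.constant((11, 100, 30), dtype=tf.int64))
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))  # [ 2 -1  0]: 11 -> index 2, 100 -> OOV, 30 -> index 0
```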
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index ad67a082dc9..59aa39411f5 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -1193,10 +1193,22 @@ class MakeInputLayerTest(test.TestCase): self.assertAllClose([[1., 3.]], net2.eval()) -class VocabularyCategoricalColumnTest(test.TestCase): +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class VocabularyFileCategoricalColumnTest(test.TestCase): def setUp(self): - super(VocabularyCategoricalColumnTest, self).setUp() + super(VocabularyFileCategoricalColumnTest, self).setUp() # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22 self._warriors_vocabulary_file_name = test.test_src_dir_path( 'python/feature_column/testdata/warriors_vocabulary.txt') self._warriors_vocabulary_size = 5 @@ -1208,17 +1220,6 @@ class VocabularyCategoricalColumnTest(test.TestCase): 'python/feature_column/testdata/wire_vocabulary.txt') self._wire_vocabulary_size = 3 - def _assert_sparse_tensor_value(self, expected, actual): - self.assertEqual(np.int64, np.array(actual.indices).dtype) - self.assertAllEqual(expected.indices, actual.indices) - - self.assertEqual( - np.array(expected.values).dtype, np.array(actual.values).dtype) - self.assertAllEqual(expected.values, actual.values) - - self.assertEqual(np.int64, np.array(actual.dense_shape).dtype) - self.assertAllEqual(expected.dense_shape, actual.dense_shape) - def test_defaults(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) @@ -1316,7 +1317,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): num_oov_buckets=-1) def test_invalid_dtype(self): - with self.assertRaisesRegexp(ValueError, 'Invalid dtype'): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path', vocabulary_size=3, dtype=dtypes.float64) @@ -1331,6 +1332,36 @@ class VocabularyCategoricalColumnTest(test.TestCase): num_oov_buckets=100, default_value=2) + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + dtype=dtypes.string) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + def test_get_sparse_tensors(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', @@ -1346,7 +1377,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, -1, 0), dtype=np.int64), @@ -1365,7 +1397,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1)), values=np.array((2, -1, 0), dtype=np.int64), @@ -1388,7 +1421,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 2, 0), dtype=np.int64), @@ -1411,7 +1445,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 33, 0, 62), dtype=np.int64), @@ -1436,7 +1471,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((-1, -1, 0), dtype=np.int64), @@ -1459,7 +1495,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, -1, 0, 4), dtype=np.int64), @@ -1481,7 +1518,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1), (2, 2)), values=np.array((2, default_value, 0, 4), dtype=np.int64), dense_shape=(3, 3)), id_weight_pair.id_tensor.eval()) @@ -1505,7 +1543,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 60, 0, 4), dtype=np.int64), @@ -1538,5 +1577,256 @@ class VocabularyCategoricalColumnTest(test.TestCase): self.assertAllClose(((3.,), (5.,)), predictions.eval()) +class VocabularyListCategoricalColumnTest(test.TestCase): + + def test_defaults_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_defaults_int(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36)) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32, + default_value=-99) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_vocabulary_list.""" + original = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.float32) + + def test_invalid_mapping_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12., 24., 36.)) + + def test_mismatched_int_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.int32) + + def test_mismatched_string_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 
36), dtype=dtypes.string) + + def test_none_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=None) + + def test_empty_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=tuple([])) + + def test_duplicate_mapping(self): + with self.assertRaisesRegexp(ValueError, 'Duplicate keys'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 12)) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=(12, 24, 36)) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + 
self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((11, 100, 30, 22), dtype=np.int32), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32, + default_value=default_value) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': np.array( + ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), + dtype=np.int32) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual(3, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> None, 'omar' -> 0: wire_var[0] = 1 + self.assertAllClose(((3.,), (1.,)), predictions.eval()) + + if __name__ == '__main__': test.main()
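To cross-check the expectations in the final `make_linear_model` test, the per-row arithmetic reduces to summing the weights of each row's looked-up IDs, with `-1` (out-of-vocabulary) entries dropped by the sparse lookup, as the inline comments note. A pure-Python sketch of that arithmetic, not code from this patch:

```python
wire_var = ((1.,), (2.,), (3.,))  # weights set via wire_var.assign(...)
row_ids = ((2,), (-1, 0))         # row 0: 'marlo'; row 1: 'skywalker' (OOV), 'omar'
predictions = [sum(wire_var[i][0] for i in ids if i >= 0) for ids in row_ids]
assert predictions == [3.0, 1.0]  # matches assertAllClose(((3.,), (1.,)))
```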