Add categorical_column_with_vocabulary_list
.
Change: 155158042
This commit is contained in:
parent
d48f3a9a3f
commit
afd69fc26f
@ -121,6 +121,8 @@ from __future__ import print_function
|
|||||||
import abc
|
import abc
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.feature_column import lookup_ops
|
from tensorflow.python.feature_column import lookup_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -433,6 +435,12 @@ def bucketized_column(source_column, boundaries):
|
|||||||
return _BucketizedColumn(source_column, tuple(boundaries))
|
return _BucketizedColumn(source_column, tuple(boundaries))
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_string_or_int(dtype, prefix):
|
||||||
|
if (dtype != dtypes.string) and (not dtype.is_integer):
|
||||||
|
raise ValueError(
|
||||||
|
'{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
|
||||||
|
|
||||||
|
|
||||||
def categorical_column_with_hash_bucket(key,
|
def categorical_column_with_hash_bucket(key,
|
||||||
hash_bucket_size,
|
hash_bucket_size,
|
||||||
dtype=dtypes.string):
|
dtype=dtypes.string):
|
||||||
@ -475,9 +483,7 @@ def categorical_column_with_hash_bucket(key,
|
|||||||
'hash_bucket_size: {}, key: {}'.format(
|
'hash_bucket_size: {}, key: {}'.format(
|
||||||
hash_bucket_size, key))
|
hash_bucket_size, key))
|
||||||
|
|
||||||
if dtype != dtypes.string and not dtype.is_integer:
|
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
|
||||||
raise ValueError('dtype must be string or integer. '
|
|
||||||
'dtype: {}, column_name: {}'.format(dtype, key))
|
|
||||||
|
|
||||||
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
|
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
|
||||||
|
|
||||||
@ -485,7 +491,7 @@ def categorical_column_with_hash_bucket(key,
|
|||||||
def categorical_column_with_vocabulary_file(
|
def categorical_column_with_vocabulary_file(
|
||||||
key, vocabulary_file, vocabulary_size, num_oov_buckets=0,
|
key, vocabulary_file, vocabulary_size, num_oov_buckets=0,
|
||||||
default_value=None, dtype=dtypes.string):
|
default_value=None, dtype=dtypes.string):
|
||||||
"""Creates a `_CategoricalColumn` with vocabulary file configuration.
|
"""A `_CategoricalColumn` with a vocabulary file.
|
||||||
|
|
||||||
Use this when your inputs are in string or integer format, and you have a
|
Use this when your inputs are in string or integer format, and you have a
|
||||||
vocabulary file that maps each value to an integer ID. By default,
|
vocabulary file that maps each value to an integer ID. By default,
|
||||||
@ -504,7 +510,7 @@ def categorical_column_with_vocabulary_file(
|
|||||||
ID 50-54.
|
ID 50-54.
|
||||||
```python
|
```python
|
||||||
states = categorical_column_with_vocabulary_file(
|
states = categorical_column_with_vocabulary_file(
|
||||||
key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=50,
|
key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
|
||||||
num_oov_buckets=5)
|
num_oov_buckets=5)
|
||||||
linear_prediction = make_linear_model(features, [states, ...])
|
linear_prediction = make_linear_model(features, [states, ...])
|
||||||
```
|
```
|
||||||
@ -516,7 +522,7 @@ def categorical_column_with_vocabulary_file(
|
|||||||
others are assigned the corresponding line number 1-50.
|
others are assigned the corresponding line number 1-50.
|
||||||
```python
|
```python
|
||||||
states = categorical_column_with_vocabulary_file(
|
states = categorical_column_with_vocabulary_file(
|
||||||
key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=51,
|
key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
|
||||||
default_value=0)
|
default_value=0)
|
||||||
linear_prediction, _, _ = make_linear_model(features, [states, ...])
|
linear_prediction, _, _ = make_linear_model(features, [states, ...])
|
||||||
|
|
||||||
@ -530,7 +536,9 @@ def categorical_column_with_vocabulary_file(
|
|||||||
column name and the dictionary key for feature parsing configs, feature
|
column name and the dictionary key for feature parsing configs, feature
|
||||||
`Tensor` objects, and feature columns.
|
`Tensor` objects, and feature columns.
|
||||||
vocabulary_file: The vocabulary file name.
|
vocabulary_file: The vocabulary file name.
|
||||||
vocabulary_size: Number of the elements in the vocabulary.
|
vocabulary_size: Number of the elements in the vocabulary. This must be no
|
||||||
|
greater than length of `vocabulary_file`, if less than length, later
|
||||||
|
values are ignored.
|
||||||
num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
|
num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
|
||||||
buckets. All out-of-vocabulary inputs will be assigned IDs in the range
|
buckets. All out-of-vocabulary inputs will be assigned IDs in the range
|
||||||
`[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
|
`[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
|
||||||
@ -542,7 +550,7 @@ def categorical_column_with_vocabulary_file(
|
|||||||
dtype: The type of features. Only string and integer types are supported.
|
dtype: The type of features. Only string and integer types are supported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `_CategoricalColumn` with vocabulary file configuration.
|
A `_CategoricalColumn` with a vocabulary file.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: `vocabulary_file` is missing.
|
ValueError: `vocabulary_file` is missing.
|
||||||
@ -564,9 +572,8 @@ def categorical_column_with_vocabulary_file(
|
|||||||
if num_oov_buckets < 0:
|
if num_oov_buckets < 0:
|
||||||
raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
|
raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
|
||||||
num_oov_buckets, key))
|
num_oov_buckets, key))
|
||||||
if dtype != dtypes.string and not dtype.is_integer:
|
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
|
||||||
raise ValueError('Invalid dtype {} in {}.'.format(dtype, key))
|
return _VocabularyFileCategoricalColumn(
|
||||||
return _VocabularyCategoricalColumn(
|
|
||||||
key=key,
|
key=key,
|
||||||
vocabulary_file=vocabulary_file,
|
vocabulary_file=vocabulary_file,
|
||||||
vocabulary_size=vocabulary_size,
|
vocabulary_size=vocabulary_size,
|
||||||
@ -575,6 +582,80 @@ def categorical_column_with_vocabulary_file(
|
|||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def categorical_column_with_vocabulary_list(
|
||||||
|
key, vocabulary_list, dtype=None, default_value=-1):
|
||||||
|
"""A `_CategoricalColumn` with in-memory vocabulary.
|
||||||
|
|
||||||
|
Logic for feature f is:
|
||||||
|
id = f in vocabulary_list ? vocabulary_list.index(f) : default_value
|
||||||
|
|
||||||
|
Use this when your inputs are in string or integer format, and you have an
|
||||||
|
in-memory vocabulary mapping each value to an integer ID. By default,
|
||||||
|
out-of-vocabulary values are ignored. Use `default_value` to specify how to
|
||||||
|
include out-of-vocabulary values.
|
||||||
|
|
||||||
|
Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values
|
||||||
|
can be represented by `-1` for int and `''` for string. Note that these values
|
||||||
|
are independent of the `default_value` argument.
|
||||||
|
|
||||||
|
In the following examples, each input in `vocabulary_list` is assigned an ID
|
||||||
|
0-4 corresponding to its index (e.g., input 'B' produces output 2). All other
|
||||||
|
inputs are assigned `default_value` 0.
|
||||||
|
|
||||||
|
Linear model:
|
||||||
|
```python
|
||||||
|
colors = categorical_column_with_vocabulary_list(
|
||||||
|
key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
|
||||||
|
linear_prediction, _, _ = make_linear_model(features, [colors, ...])
|
||||||
|
```
|
||||||
|
|
||||||
|
Embedding for a DNN model:
|
||||||
|
```python
|
||||||
|
dense_tensor = make_input_layer(features, [embedding_column(colors, 3),...])
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: A unique string identifying the input feature. It is used as the
|
||||||
|
column name and the dictionary key for feature parsing configs, feature
|
||||||
|
`Tensor` objects, and feature columns.
|
||||||
|
vocabulary_list: An ordered iterable defining the vocabulary. Each feature
|
||||||
|
is mapped to the index of its value (if present) in `vocabulary_list`.
|
||||||
|
Must be castable to `dtype`.
|
||||||
|
dtype: The type of features. Only string and integer types are supported.
|
||||||
|
If `None`, it will be inferred from `vocabulary_list`.
|
||||||
|
default_value: The value to use for values not in `vocabulary_list`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `_CategoricalColumn` with in-memory vocabulary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
|
||||||
|
ValueError: if `dtype` is not integer or string.
|
||||||
|
"""
|
||||||
|
if (vocabulary_list is None) or (len(vocabulary_list) < 1):
|
||||||
|
raise ValueError(
|
||||||
|
'vocabulary_list {} must be non-empty, column_name: {}'.format(
|
||||||
|
vocabulary_list, key))
|
||||||
|
if len(set(vocabulary_list)) != len(vocabulary_list):
|
||||||
|
raise ValueError(
|
||||||
|
'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
|
||||||
|
vocabulary_list, key))
|
||||||
|
vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
|
||||||
|
_assert_string_or_int(
|
||||||
|
vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
|
||||||
|
if dtype is None:
|
||||||
|
dtype = vocabulary_dtype
|
||||||
|
elif dtype.is_integer != vocabulary_dtype.is_integer:
|
||||||
|
raise ValueError(
|
||||||
|
'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
|
||||||
|
dtype, vocabulary_dtype, key))
|
||||||
|
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
|
||||||
|
|
||||||
|
return _VocabularyListCategoricalColumn(
|
||||||
|
key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
|
||||||
|
default_value=default_value)
|
||||||
|
|
||||||
|
|
||||||
class _FeatureColumn(object):
|
class _FeatureColumn(object):
|
||||||
"""Represents a feature column abstraction.
|
"""Represents a feature column abstraction.
|
||||||
|
|
||||||
@ -1170,11 +1251,9 @@ class _HashedCategoricalColumn(
|
|||||||
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
|
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
|
||||||
raise ValueError('SparseColumn input must be a SparseTensor.')
|
raise ValueError('SparseColumn input must be a SparseTensor.')
|
||||||
|
|
||||||
if (input_tensor.dtype != dtypes.string and
|
_assert_string_or_int(
|
||||||
not input_tensor.dtype.is_integer):
|
input_tensor.dtype,
|
||||||
raise ValueError('input tensors dtype must be string or integer. '
|
prefix='column_name: {} input_tensor'.format(self.key))
|
||||||
'dtype: {}, column_name: {}'.format(
|
|
||||||
input_tensor.dtype, self.key))
|
|
||||||
|
|
||||||
if self.dtype.is_integer != input_tensor.dtype.is_integer:
|
if self.dtype.is_integer != input_tensor.dtype.is_integer:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1202,8 +1281,9 @@ class _HashedCategoricalColumn(
|
|||||||
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||||
|
|
||||||
|
|
||||||
class _VocabularyCategoricalColumn(
|
class _VocabularyFileCategoricalColumn(
|
||||||
_CategoricalColumn, collections.namedtuple('_VocabularyCategoricalColumn', (
|
_CategoricalColumn,
|
||||||
|
collections.namedtuple('_VocabularyFileCategoricalColumn', (
|
||||||
'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype',
|
'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype',
|
||||||
'default_value'
|
'default_value'
|
||||||
))):
|
))):
|
||||||
@ -1226,15 +1306,15 @@ class _VocabularyCategoricalColumn(
|
|||||||
'key: {}, column dtype: {}, tensor dtype: {}'.format(
|
'key: {}, column dtype: {}, tensor dtype: {}'.format(
|
||||||
self.key, self.dtype, input_tensor.dtype))
|
self.key, self.dtype, input_tensor.dtype))
|
||||||
|
|
||||||
|
_assert_string_or_int(
|
||||||
|
input_tensor.dtype,
|
||||||
|
prefix='column_name: {} input_tensor'.format(self.key))
|
||||||
|
|
||||||
key_dtype = self.dtype
|
key_dtype = self.dtype
|
||||||
if input_tensor.dtype.is_integer:
|
if input_tensor.dtype.is_integer:
|
||||||
# `index_table_from_file` requires 64-bit integer keys.
|
# `index_table_from_file` requires 64-bit integer keys.
|
||||||
key_dtype = dtypes.int64
|
key_dtype = dtypes.int64
|
||||||
input_tensor = math_ops.to_int64(input_tensor)
|
input_tensor = math_ops.to_int64(input_tensor)
|
||||||
elif input_tensor.dtype != dtypes.string:
|
|
||||||
raise ValueError('input tensors dtype must be string or integer. '
|
|
||||||
'dtype: {}, column_name: {}'.format(
|
|
||||||
input_tensor.dtype, self.key))
|
|
||||||
|
|
||||||
return lookup_ops.index_table_from_file(
|
return lookup_ops.index_table_from_file(
|
||||||
vocabulary_file=self.vocabulary_file,
|
vocabulary_file=self.vocabulary_file,
|
||||||
@ -1254,6 +1334,56 @@ class _VocabularyCategoricalColumn(
|
|||||||
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||||
|
|
||||||
|
|
||||||
|
class _VocabularyListCategoricalColumn(
|
||||||
|
_CategoricalColumn,
|
||||||
|
collections.namedtuple('_VocabularyListCategoricalColumn', (
|
||||||
|
'key', 'vocabulary_list', 'dtype', 'default_value'
|
||||||
|
))):
|
||||||
|
"""See `categorical_column_with_vocabulary_list`."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _parse_example_config(self):
|
||||||
|
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
|
||||||
|
|
||||||
|
def _transform_feature(self, inputs):
|
||||||
|
input_tensor = _to_sparse_input(inputs.get(self.key))
|
||||||
|
|
||||||
|
if self.dtype.is_integer != input_tensor.dtype.is_integer:
|
||||||
|
raise ValueError(
|
||||||
|
'Column dtype and SparseTensors dtype must be compatible. '
|
||||||
|
'key: {}, column dtype: {}, tensor dtype: {}'.format(
|
||||||
|
self.key, self.dtype, input_tensor.dtype))
|
||||||
|
|
||||||
|
_assert_string_or_int(
|
||||||
|
input_tensor.dtype,
|
||||||
|
prefix='column_name: {} input_tensor'.format(self.key))
|
||||||
|
|
||||||
|
key_dtype = self.dtype
|
||||||
|
if input_tensor.dtype.is_integer:
|
||||||
|
# `index_table_from_tensor` requires 64-bit integer keys.
|
||||||
|
key_dtype = dtypes.int64
|
||||||
|
input_tensor = math_ops.to_int64(input_tensor)
|
||||||
|
|
||||||
|
return lookup_ops.index_table_from_tensor(
|
||||||
|
mapping=tuple(self.vocabulary_list),
|
||||||
|
default_value=self.default_value,
|
||||||
|
dtype=key_dtype,
|
||||||
|
name='{}_lookup'.format(self.key)).lookup(input_tensor)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _num_buckets(self):
|
||||||
|
"""Returns number of buckets in this sparse feature."""
|
||||||
|
return len(self.vocabulary_list)
|
||||||
|
|
||||||
|
def _get_sparse_tensors(
|
||||||
|
self, inputs, weight_collections=None, trainable=None):
|
||||||
|
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||||
|
|
||||||
|
|
||||||
# TODO(zakaria): Move this to embedding_ops and make it public.
|
# TODO(zakaria): Move this to embedding_ops and make it public.
|
||||||
def _safe_embedding_lookup_sparse(embedding_weights,
|
def _safe_embedding_lookup_sparse(embedding_weights,
|
||||||
sparse_ids,
|
sparse_ids,
|
||||||
|
@ -1193,10 +1193,22 @@ class MakeInputLayerTest(test.TestCase):
|
|||||||
self.assertAllClose([[1., 3.]], net2.eval())
|
self.assertAllClose([[1., 3.]], net2.eval())
|
||||||
|
|
||||||
|
|
||||||
class VocabularyCategoricalColumnTest(test.TestCase):
|
def _assert_sparse_tensor_value(test_case, expected, actual):
|
||||||
|
test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
|
||||||
|
test_case.assertAllEqual(expected.indices, actual.indices)
|
||||||
|
|
||||||
|
test_case.assertEqual(
|
||||||
|
np.array(expected.values).dtype, np.array(actual.values).dtype)
|
||||||
|
test_case.assertAllEqual(expected.values, actual.values)
|
||||||
|
|
||||||
|
test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
|
||||||
|
test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class VocabularyFileCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(VocabularyCategoricalColumnTest, self).setUp()
|
super(VocabularyFileCategoricalColumnTest, self).setUp()
|
||||||
|
|
||||||
# Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
|
# Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
|
||||||
self._warriors_vocabulary_file_name = test.test_src_dir_path(
|
self._warriors_vocabulary_file_name = test.test_src_dir_path(
|
||||||
@ -1208,17 +1220,6 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
'python/feature_column/testdata/wire_vocabulary.txt')
|
'python/feature_column/testdata/wire_vocabulary.txt')
|
||||||
self._wire_vocabulary_size = 3
|
self._wire_vocabulary_size = 3
|
||||||
|
|
||||||
def _assert_sparse_tensor_value(self, expected, actual):
|
|
||||||
self.assertEqual(np.int64, np.array(actual.indices).dtype)
|
|
||||||
self.assertAllEqual(expected.indices, actual.indices)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
np.array(expected.values).dtype, np.array(actual.values).dtype)
|
|
||||||
self.assertAllEqual(expected.values, actual.values)
|
|
||||||
|
|
||||||
self.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
|
|
||||||
self.assertAllEqual(expected.dense_shape, actual.dense_shape)
|
|
||||||
|
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
|
||||||
@ -1316,7 +1317,7 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
num_oov_buckets=-1)
|
num_oov_buckets=-1)
|
||||||
|
|
||||||
def test_invalid_dtype(self):
|
def test_invalid_dtype(self):
|
||||||
with self.assertRaisesRegexp(ValueError, 'Invalid dtype'):
|
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
||||||
fc.categorical_column_with_vocabulary_file(
|
fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa', vocabulary_file='path', vocabulary_size=3,
|
key='aaa', vocabulary_file='path', vocabulary_size=3,
|
||||||
dtype=dtypes.float64)
|
dtype=dtypes.float64)
|
||||||
@ -1331,6 +1332,36 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
num_oov_buckets=100,
|
num_oov_buckets=100,
|
||||||
default_value=2)
|
default_value=2)
|
||||||
|
|
||||||
|
def test_invalid_input_dtype_int32(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_file=self._wire_vocabulary_file_name,
|
||||||
|
vocabulary_size=self._wire_vocabulary_size,
|
||||||
|
dtype=dtypes.string)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(12, 24, 36),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_invalid_input_dtype_string(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_file=self._warriors_vocabulary_file_name,
|
||||||
|
vocabulary_size=self._warriors_vocabulary_size,
|
||||||
|
dtype=dtypes.int32)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('omar', 'stringer', 'marlo'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
def test_get_sparse_tensors(self):
|
def test_get_sparse_tensors(self):
|
||||||
column = fc.categorical_column_with_vocabulary_file(
|
column = fc.categorical_column_with_vocabulary_file(
|
||||||
key='aaa',
|
key='aaa',
|
||||||
@ -1346,7 +1377,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
@ -1365,7 +1397,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=((0, 0), (1, 0), (1, 1)),
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
values=np.array((2, -1, 0), dtype=np.int64),
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
@ -1388,7 +1421,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((2, 2, 0), dtype=np.int64),
|
values=np.array((2, 2, 0), dtype=np.int64),
|
||||||
@ -1411,7 +1445,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((2, 33, 0, 62), dtype=np.int64),
|
values=np.array((2, 33, 0, 62), dtype=np.int64),
|
||||||
@ -1436,7 +1471,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((-1, -1, 0), dtype=np.int64),
|
values=np.array((-1, -1, 0), dtype=np.int64),
|
||||||
@ -1459,7 +1495,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((2, -1, 0, 4), dtype=np.int64),
|
values=np.array((2, -1, 0, 4), dtype=np.int64),
|
||||||
@ -1481,7 +1518,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
|
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
|
||||||
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
||||||
@ -1505,7 +1543,8 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
with _initialized_session():
|
with _initialized_session():
|
||||||
self._assert_sparse_tensor_value(
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
sparse_tensor.SparseTensorValue(
|
sparse_tensor.SparseTensorValue(
|
||||||
indices=inputs.indices,
|
indices=inputs.indices,
|
||||||
values=np.array((2, 60, 0, 4), dtype=np.int64),
|
values=np.array((2, 60, 0, 4), dtype=np.int64),
|
||||||
@ -1538,5 +1577,256 @@ class VocabularyCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((3.,), (5.,)), predictions.eval())
|
self.assertAllClose(((3.,), (5.,)), predictions.eval())
|
||||||
|
|
||||||
|
|
||||||
|
class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_defaults_string(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
|
self.assertEqual('aaa', column.name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.string)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_defaults_int(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12, 24, 36))
|
||||||
|
self.assertEqual('aaa', column.name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_all_constructor_args(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
|
||||||
|
default_value=-99)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_deep_copy(self):
|
||||||
|
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
||||||
|
original = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
|
||||||
|
for column in (original, copy.deepcopy(original)):
|
||||||
|
self.assertEqual('aaa', column.name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_invalid_dtype(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
|
||||||
|
def test_invalid_mapping_dtype(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, r'vocabulary dtype must be string or integer'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12., 24., 36.))
|
||||||
|
|
||||||
|
def test_mismatched_int_dtype(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, r'dtype.*and vocabulary dtype.*do not match'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
|
||||||
|
dtype=dtypes.int32)
|
||||||
|
|
||||||
|
def test_mismatched_string_dtype(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, r'dtype.*and vocabulary dtype.*do not match'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string)
|
||||||
|
|
||||||
|
def test_none_mapping(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, r'vocabulary_list.*must be non-empty'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=None)
|
||||||
|
|
||||||
|
def test_empty_mapping(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, r'vocabulary_list.*must be non-empty'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=tuple([]))
|
||||||
|
|
||||||
|
def test_duplicate_mapping(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Duplicate keys'):
|
||||||
|
fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa', vocabulary_list=(12, 24, 12))
|
||||||
|
|
||||||
|
def test_invalid_input_dtype_int32(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(12, 24, 36),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_invalid_input_dtype_string(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=(12, 24, 36))
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('omar', 'stringer', 'marlo'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_get_sparse_tensors(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('marlo', 'skywalker', 'omar'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=inputs.indices,
|
||||||
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
|
dense_shape=inputs.dense_shape),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
|
||||||
|
'aaa': (('marlo', ''), ('skywalker', 'omar'))
|
||||||
|
}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=np.array((2, -1, 0), dtype=np.int64),
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_default_value_in_vocabulary(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=('omar', 'stringer', 'marlo'),
|
||||||
|
default_value=2)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('marlo', 'skywalker', 'omar'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=inputs.indices,
|
||||||
|
values=np.array((2, 2, 0), dtype=np.int64),
|
||||||
|
dense_shape=inputs.dense_shape),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_int32(self):
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
|
||||||
|
dtype=dtypes.int32)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
|
||||||
|
values=np.array((11, 100, 30, 22), dtype=np.int32),
|
||||||
|
dense_shape=(3, 3))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=inputs.indices,
|
||||||
|
values=np.array((2, -1, 0, 4), dtype=np.int64),
|
||||||
|
dense_shape=inputs.dense_shape),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_int32_dense_input(self):
|
||||||
|
default_value = -100
|
||||||
|
column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
|
||||||
|
dtype=dtypes.int32,
|
||||||
|
default_value=default_value)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
|
||||||
|
'aaa': np.array(
|
||||||
|
((11, -1, -1), (100, 30, -1), (-1, -1, 22)),
|
||||||
|
dtype=np.int32)
|
||||||
|
}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
|
||||||
|
values=np.array((2, default_value, 0, 4), dtype=np.int64),
|
||||||
|
dense_shape=(3, 3)),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_make_linear_model(self):
|
||||||
|
wire_column = fc.categorical_column_with_vocabulary_list(
|
||||||
|
key='aaa',
|
||||||
|
vocabulary_list=('omar', 'stringer', 'marlo'))
|
||||||
|
self.assertEqual(3, wire_column._num_buckets)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
predictions = fc.make_linear_model({
|
||||||
|
wire_column.name: sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('marlo', 'skywalker', 'omar'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
}, (wire_column,))
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
wire_var = get_linear_model_column_var(wire_column)
|
||||||
|
with _initialized_session():
|
||||||
|
self.assertAllClose((0.,), bias.eval())
|
||||||
|
self.assertAllClose(((0.,), (0.,), (0.,)), wire_var.eval())
|
||||||
|
self.assertAllClose(((0.,), (0.,)), predictions.eval())
|
||||||
|
wire_var.assign(((1.,), (2.,), (3.,))).eval()
|
||||||
|
# 'marlo' -> 2: wire_var[2] = 3
|
||||||
|
# 'skywalker' -> None, 'omar' -> 0: wire_var[0] = 1
|
||||||
|
self.assertAllClose(((3.,), (1.,)), predictions.eval())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user