Move crossed_column to core.
PiperOrigin-RevId: 155687697
This commit is contained in:
parent
770a27161b
commit
e09b0b6ebf
@ -607,6 +607,18 @@ def bucketized_column(source_column, boundaries):
|
|||||||
dense_tensor = make_input_layer(features, columns)
|
dense_tensor = make_input_layer(features, columns)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`bucketized_column` can also be crossed with another categorical column using
|
||||||
|
`crossed_column`:
|
||||||
|
```python
|
||||||
|
price = numeric_column('price')
|
||||||
|
# bucketized_column converts numerical feature to a categorical one.
|
||||||
|
bucketized_price = bucketized_column(price, boundaries=[...])
|
||||||
|
# 'keywords' is a string feature.
|
||||||
|
price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50K)
|
||||||
|
all_feature_columns = [price_x_keywords, ...]
|
||||||
|
linear_prediction = make_linear_model(features, all_feature_columns)
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_column: A one-dimensional dense column which is generated with
|
source_column: A one-dimensional dense column which is generated with
|
||||||
`numeric_column`.
|
`numeric_column`.
|
||||||
@ -1036,6 +1048,107 @@ def weighted_categorical_column(
|
|||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def crossed_column(keys, hash_bucket_size, hash_key=None):
|
||||||
|
"""Returns a column for performing crosses of categorical features.
|
||||||
|
|
||||||
|
Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
|
||||||
|
the transformation can be thought of as:
|
||||||
|
Hash(cartesian product of features) % `hash_bucket_size`
|
||||||
|
|
||||||
|
For example, if the input features are:
|
||||||
|
* SparseTensor referred by first key: shape = [2, 2]
|
||||||
|
[0, 0]: "a"
|
||||||
|
[1, 0]: "b"
|
||||||
|
[1, 1]: "c"
|
||||||
|
|
||||||
|
* SparseTensor referred by second key: shape = [2, 1]
|
||||||
|
[0, 0]: "d"
|
||||||
|
[1, 0]: "e"
|
||||||
|
|
||||||
|
then crossed feature will look like:
|
||||||
|
shape = [2, 2]
|
||||||
|
[0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
|
||||||
|
[1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
|
||||||
|
[1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
|
||||||
|
|
||||||
|
Here is an example to create a linear model with crosses of string features:
|
||||||
|
```python
|
||||||
|
keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
|
||||||
|
all_feature_columns = [keywords_x_doc_terms, ...]
|
||||||
|
linear_prediction = make_linear_model(features, all_feature_columns)
|
||||||
|
```
|
||||||
|
|
||||||
|
You could also use vocabulary lookup before crossing:
|
||||||
|
```python
|
||||||
|
keywords = categorical_column_with_vocabulary_file(
|
||||||
|
'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
|
||||||
|
keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
|
||||||
|
all_feature_columns = [keywords_x_doc_terms, ...]
|
||||||
|
linear_prediction = make_linear_model(features, all_feature_columns)
|
||||||
|
```
|
||||||
|
|
||||||
|
If an input feature is of numeric type, you can use
|
||||||
|
`categorical_column_with_identity`, or `bucketized_column`, as in the example:
|
||||||
|
```python
|
||||||
|
# vertical_id is an integer categorical feature.
|
||||||
|
vertical_id = categorical_column_with_identity('vertical_id', 10K)
|
||||||
|
price = numeric_column('price')
|
||||||
|
# bucketized_column converts numerical feature to a categorical one.
|
||||||
|
bucketized_price = bucketized_column(price, boundaries=[...])
|
||||||
|
vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
|
||||||
|
all_feature_columns = [vertical_id_x_price, ...]
|
||||||
|
linear_prediction = make_linear_model(features, all_feature_columns)
|
||||||
|
```
|
||||||
|
|
||||||
|
To use crossed column in DNN model, you need to add it in an embedding column
|
||||||
|
as in this example:
|
||||||
|
```python
|
||||||
|
vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
|
||||||
|
vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
|
||||||
|
dense_tensor = make_input_layer(features, [vertical_id_x_price_embedded, ...])
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys: An iterable identifying the features to be crossed. Each element can
|
||||||
|
be either:
|
||||||
|
* string: Will use the corresponding feature which must be of string type.
|
||||||
|
* `_CategoricalColumn`: Will use the transformed tensor produced by this
|
||||||
|
column. Does not support hashed categorical column.
|
||||||
|
hash_bucket_size: An int > 1. The number of buckets.
|
||||||
|
hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
|
||||||
|
function to combine the crosses fingerprints on SparseCrossOp (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `_CrossedColumn`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `len(keys) < 2`.
|
||||||
|
ValueError: If any of the keys is neither a string nor `_CategoricalColumn`.
|
||||||
|
ValueError: If any of the keys is `_HashedCategoricalColumn`.
|
||||||
|
ValueError: If `hash_bucket_size < 1`.
|
||||||
|
"""
|
||||||
|
if not hash_bucket_size or hash_bucket_size < 1:
|
||||||
|
raise ValueError('hash_bucket_size must be > 1. '
|
||||||
|
'hash_bucket_size: {}'.format(hash_bucket_size))
|
||||||
|
if not keys or len(keys) < 2:
|
||||||
|
raise ValueError(
|
||||||
|
'keys must be a list with length > 1. Given: {}'.format(keys))
|
||||||
|
for key in keys:
|
||||||
|
if (not isinstance(key, six.string_types) and
|
||||||
|
not isinstance(key, _CategoricalColumn)):
|
||||||
|
raise ValueError(
|
||||||
|
'Unsupported key type. All keys must be either string, or '
|
||||||
|
'categorical column except _HashedCategoricalColumn. '
|
||||||
|
'Given: {}'.format(key))
|
||||||
|
if isinstance(key, _HashedCategoricalColumn):
|
||||||
|
raise ValueError(
|
||||||
|
'_HashedCategoricalColumn is not supported. Instead, use the feature '
|
||||||
|
'name as a string. Given: {}'.format(key))
|
||||||
|
return _CrossedColumn(
|
||||||
|
keys=tuple(keys), hash_bucket_size=hash_bucket_size,
|
||||||
|
hash_key=hash_key)
|
||||||
|
|
||||||
|
|
||||||
class _FeatureColumn(object):
|
class _FeatureColumn(object):
|
||||||
"""Represents a feature column abstraction.
|
"""Represents a feature column abstraction.
|
||||||
|
|
||||||
@ -1969,6 +2082,80 @@ class _WeightedCategoricalColumn(
|
|||||||
return _CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
|
return _CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
|
||||||
|
|
||||||
|
|
||||||
|
class _CrossedColumn(
|
||||||
|
_CategoricalColumn,
|
||||||
|
collections.namedtuple('_CrossedColumn',
|
||||||
|
['keys', 'hash_bucket_size', 'hash_key'])):
|
||||||
|
"""See `crossed_column`."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
feature_names = []
|
||||||
|
for key in _collect_leaf_level_keys(self):
|
||||||
|
if isinstance(key, _FeatureColumn):
|
||||||
|
feature_names.append(key.name)
|
||||||
|
else: # key must be a string
|
||||||
|
feature_names.append(key)
|
||||||
|
return '_X_'.join(sorted(feature_names))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _parse_example_config(self):
|
||||||
|
config = {}
|
||||||
|
for key in self.keys:
|
||||||
|
if isinstance(key, _FeatureColumn):
|
||||||
|
config.update(key._parse_example_config) # pylint: disable=protected-access
|
||||||
|
else: # key must be a string
|
||||||
|
config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _transform_feature(self, inputs):
|
||||||
|
feature_tensors = []
|
||||||
|
for key in _collect_leaf_level_keys(self):
|
||||||
|
if isinstance(key, six.string_types):
|
||||||
|
feature_tensors.append(inputs.get(key))
|
||||||
|
elif isinstance(key, _CategoricalColumn):
|
||||||
|
ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access
|
||||||
|
if ids_and_weights.weight_tensor is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'crossed_column does not support weight_tensor, but the given '
|
||||||
|
'column populates weight_tensor. '
|
||||||
|
'Given column: {}'.format(key.name))
|
||||||
|
feature_tensors.append(ids_and_weights.id_tensor)
|
||||||
|
else:
|
||||||
|
raise ValueError('Unsupported column type. Given: {}'.format(key))
|
||||||
|
return sparse_ops._sparse_cross_hashed( # pylint: disable=protected-access
|
||||||
|
inputs=feature_tensors,
|
||||||
|
num_buckets=self.hash_bucket_size,
|
||||||
|
hash_key=self.hash_key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _num_buckets(self):
|
||||||
|
"""Returns number of buckets in this sparse feature."""
|
||||||
|
return self.hash_bucket_size
|
||||||
|
|
||||||
|
def _get_sparse_tensors(self, inputs, weight_collections=None,
|
||||||
|
trainable=None):
|
||||||
|
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_leaf_level_keys(cross):
|
||||||
|
"""Collects base keys by expanding all nested crosses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cross: A `_CrossedColumn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of strings or `_CategoricalColumn` instances.
|
||||||
|
"""
|
||||||
|
leaf_level_keys = []
|
||||||
|
for k in cross.keys:
|
||||||
|
if isinstance(k, _CrossedColumn):
|
||||||
|
leaf_level_keys.extend(_collect_leaf_level_keys(k))
|
||||||
|
else:
|
||||||
|
leaf_level_keys.append(k)
|
||||||
|
return leaf_level_keys
|
||||||
|
|
||||||
|
|
||||||
# TODO(zakaria): Move this to embedding_ops and make it public.
|
# TODO(zakaria): Move this to embedding_ops and make it public.
|
||||||
def _safe_embedding_lookup_sparse(embedding_weights,
|
def _safe_embedding_lookup_sparse(embedding_weights,
|
||||||
sparse_ids,
|
sparse_ids,
|
||||||
|
@ -735,6 +735,262 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((4.,), (6.,)), predictions.eval())
|
self.assertAllClose(((4.,), (6.,)), predictions.eval())
|
||||||
|
|
||||||
|
|
||||||
|
class CrossedColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_keys_empty(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'keys must be a list with length > 1'):
|
||||||
|
fc.crossed_column([], 10)
|
||||||
|
|
||||||
|
def test_keys_length_one(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'keys must be a list with length > 1'):
|
||||||
|
fc.crossed_column(['a'], 10)
|
||||||
|
|
||||||
|
def test_key_type_unsupported(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Unsupported key type'):
|
||||||
|
fc.crossed_column(['a', fc.numeric_column('c')], 10)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, '_HashedCategoricalColumn is not supported'):
|
||||||
|
fc.crossed_column(
|
||||||
|
['a', fc.categorical_column_with_hash_bucket('c', 10)], 10)
|
||||||
|
|
||||||
|
def test_hash_bucket_size_negative(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'hash_bucket_size must be > 1'):
|
||||||
|
fc.crossed_column(['a', 'c'], -1)
|
||||||
|
|
||||||
|
def test_hash_bucket_size_zero(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'hash_bucket_size must be > 1'):
|
||||||
|
fc.crossed_column(['a', 'c'], 0)
|
||||||
|
|
||||||
|
def test_hash_bucket_size_none(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'hash_bucket_size must be > 1'):
|
||||||
|
fc.crossed_column(['a', 'c'], None)
|
||||||
|
|
||||||
|
def test_name(self):
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
|
||||||
|
|
||||||
|
crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
|
||||||
|
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
|
||||||
|
|
||||||
|
def test_name_ordered_alphabetically(self):
|
||||||
|
"""Tests that the name does not depend on the order of given columns."""
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
|
||||||
|
|
||||||
|
crossed2 = fc.crossed_column([crossed1, 'c', b], 10)
|
||||||
|
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
|
||||||
|
|
||||||
|
def test_name_leaf_keys_ordered_alphabetically(self):
|
||||||
|
"""Tests that the name does not depend on the order of given columns."""
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed1 = fc.crossed_column(['d2', 'c'], 10)
|
||||||
|
|
||||||
|
crossed2 = fc.crossed_column([crossed1, 'd1', b], 10)
|
||||||
|
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
|
||||||
|
|
||||||
|
def test_parse_config(self):
|
||||||
|
a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed = fc.crossed_column([b, 'c'], 10)
|
||||||
|
self.assertEqual({
|
||||||
|
'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
|
||||||
|
'c': parsing_ops.VarLenFeature(dtypes.string),
|
||||||
|
}, crossed._parse_example_config)
|
||||||
|
|
||||||
|
def test_num_buckets(self):
|
||||||
|
a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed = fc.crossed_column([b, 'c'], 15)
|
||||||
|
self.assertEqual(15, crossed._num_buckets)
|
||||||
|
|
||||||
|
def test_deep_copy(self):
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32)
|
||||||
|
b = fc.bucketized_column(a, boundaries=[0, 1])
|
||||||
|
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
|
||||||
|
crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
|
||||||
|
crossed2_copy = copy.deepcopy(crossed2)
|
||||||
|
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2_copy.name,)
|
||||||
|
self.assertEqual(15, crossed2_copy.hash_bucket_size)
|
||||||
|
self.assertEqual(5, crossed2_copy.hash_key)
|
||||||
|
|
||||||
|
def test_parse_example(self):
|
||||||
|
price = fc.numeric_column('price', shape=[2])
|
||||||
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
|
||||||
|
price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10)
|
||||||
|
data = example_pb2.Example(features=feature_pb2.Features(
|
||||||
|
feature={
|
||||||
|
'price':
|
||||||
|
feature_pb2.Feature(float_list=feature_pb2.FloatList(
|
||||||
|
value=[20., 110.])),
|
||||||
|
'wire':
|
||||||
|
feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
|
||||||
|
value=[b'omar', b'stringer'])),
|
||||||
|
}))
|
||||||
|
features = parsing_ops.parse_example(
|
||||||
|
serialized=[data.SerializeToString()],
|
||||||
|
features=price_cross_wire._parse_example_config)
|
||||||
|
self.assertIn('price', features)
|
||||||
|
self.assertIn('wire', features)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAllEqual([[20., 110.]], features['price'].eval())
|
||||||
|
wire_sparse = features['wire']
|
||||||
|
self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
|
||||||
|
# Use byte constants to pass the open-source test.
|
||||||
|
self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
|
||||||
|
self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors(self):
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
|
b = fc.bucketized_column(a, boundaries=(0, 1))
|
||||||
|
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
|
||||||
|
crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
builder = fc._LazyBuilder({
|
||||||
|
'a': constant_op.constant(((-1., .5), (.5, 1.))),
|
||||||
|
'c': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['cA', 'cB', 'cC'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
'd1': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['d1A', 'd1B', 'd1C'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
'd2': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['d2A', 'd2B', 'd2C'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
})
|
||||||
|
id_weight_pair = crossed2._get_sparse_tensors(builder)
|
||||||
|
with _initialized_session():
|
||||||
|
id_tensor_eval = id_weight_pair.id_tensor.eval()
|
||||||
|
self.assertAllEqual(
|
||||||
|
((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
|
||||||
|
(1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13),
|
||||||
|
(1, 14), (1, 15)),
|
||||||
|
id_tensor_eval.indices)
|
||||||
|
# Check exact hashed output. If hashing changes this test will break.
|
||||||
|
# All values are within [0, hash_bucket_size).
|
||||||
|
expected_values = (
|
||||||
|
6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11)
|
||||||
|
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
||||||
|
self.assertAllEqual((2, 16), id_tensor_eval.dense_shape)
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_simple(self):
|
||||||
|
"""Same as test_get_sparse_tensors, but with simpler values."""
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
|
b = fc.bucketized_column(a, boundaries=(0, 1))
|
||||||
|
crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
builder = fc._LazyBuilder({
|
||||||
|
'a': constant_op.constant(((-1., .5), (.5, 1.))),
|
||||||
|
'c': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['cA', 'cB', 'cC'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
})
|
||||||
|
id_weight_pair = crossed._get_sparse_tensors(builder)
|
||||||
|
with _initialized_session():
|
||||||
|
id_tensor_eval = id_weight_pair.id_tensor.eval()
|
||||||
|
self.assertAllEqual(
|
||||||
|
((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)),
|
||||||
|
id_tensor_eval.indices)
|
||||||
|
# Check exact hashed output. If hashing changes this test will break.
|
||||||
|
# All values are within [0, hash_bucket_size).
|
||||||
|
expected_values = (1, 0, 1, 3, 4, 2)
|
||||||
|
self.assertAllEqual(expected_values, id_tensor_eval.values)
|
||||||
|
self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
|
||||||
|
|
||||||
|
def test_make_linear_model(self):
|
||||||
|
"""Tests make_linear_model.
|
||||||
|
|
||||||
|
Uses data from test_get_sparse_tesnsors_simple.
|
||||||
|
"""
|
||||||
|
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
|
||||||
|
b = fc.bucketized_column(a, boundaries=(0, 1))
|
||||||
|
crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
predictions = fc.make_linear_model({
|
||||||
|
'a': constant_op.constant(((-1., .5), (.5, 1.))),
|
||||||
|
'c': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['cA', 'cB', 'cC'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
}, (crossed,))
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
crossed_var = get_linear_model_column_var(crossed)
|
||||||
|
with _initialized_session() as sess:
|
||||||
|
self.assertAllClose((0.,), bias.eval())
|
||||||
|
self.assertAllClose(
|
||||||
|
((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
|
||||||
|
self.assertAllClose(((0.,), (0.,)), predictions.eval())
|
||||||
|
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
|
||||||
|
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
|
||||||
|
self.assertAllClose(((3.,), (14.,)), predictions.eval())
|
||||||
|
sess.run(bias.assign((.1,)))
|
||||||
|
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
|
||||||
|
|
||||||
|
def test_make_linear_model_with_weights(self):
|
||||||
|
class _TestColumnWithWeights(fc._CategoricalColumn):
|
||||||
|
"""Produces sparse IDs and sparse weights."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return 'test_column'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _parse_example_config(self):
|
||||||
|
return {
|
||||||
|
self.name: parsing_ops.VarLenFeature(dtypes.int32),
|
||||||
|
'{}_weights'.format(self.name): parsing_ops.VarLenFeature(
|
||||||
|
dtypes.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _num_buckets(self):
|
||||||
|
return 5
|
||||||
|
|
||||||
|
def _transform_feature(self, inputs):
|
||||||
|
return (inputs.get(self.name),
|
||||||
|
inputs.get('{}_weights'.format(self.name)))
|
||||||
|
|
||||||
|
def _get_sparse_tensors(self, inputs, weight_collections=None,
|
||||||
|
trainable=None):
|
||||||
|
"""Populates both id_tensor and weight_tensor."""
|
||||||
|
ids_and_weights = inputs.get(self)
|
||||||
|
return fc._CategoricalColumn.IdWeightPair(
|
||||||
|
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
|
||||||
|
|
||||||
|
t = _TestColumnWithWeights()
|
||||||
|
crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
|
||||||
|
fc.make_linear_model({
|
||||||
|
t.name: sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=[0, 1, 2],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
'{}_weights'.format(t.name): sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=[1., 10., 2.],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
'c': sparse_tensor.SparseTensor(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=['cA', 'cB', 'cC'],
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
}, (crossed,))
|
||||||
|
|
||||||
|
|
||||||
def get_linear_model_bias():
|
def get_linear_model_bias():
|
||||||
with variable_scope.variable_scope('make_linear_model', reuse=True):
|
with variable_scope.variable_scope('make_linear_model', reuse=True):
|
||||||
return variable_scope.get_variable('bias_weights')
|
return variable_scope.get_variable('bias_weights')
|
||||||
|
Loading…
Reference in New Issue
Block a user