Move bucketized_column to core.

Change: 154963963
A. Unique TensorFlower 2017-05-03 08:20:17 -08:00, committed by TensorFlower Gardener
parent 27aaf4a653
commit b93dd62e8a
2 changed files with 401 additions and 7 deletions


@@ -200,13 +200,13 @@ def make_linear_model(features,
   builder = _LazyBuilder(features)
   for column in sorted(feature_columns, key=lambda x: x.name):
     with variable_scope.variable_scope(None, default_name=column.name):
-      if isinstance(column, _DenseColumn):
-        weigthed_sums.append(_create_dense_column_weighted_sum(
-            column, builder, units, weight_collections, trainable))
-      else:
+      if isinstance(column, _CategoricalColumn):
         weigthed_sums.append(_create_categorical_column_weighted_sum(
             column, builder, units, sparse_combiner, weight_collections,
             trainable))
+      else:
+        weigthed_sums.append(_create_dense_column_weighted_sum(
+            column, builder, units, weight_collections, trainable))
   predictions_no_bias = math_ops.add_n(
       weigthed_sums, name='weighted_sum_no_bias')
   bias = variable_scope.get_variable(
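The branch order here matters because a column can implement both interfaces: the `_BucketizedColumn` added below is both a `_DenseColumn` and a `_CategoricalColumn`, and checking `_CategoricalColumn` first routes such columns through the sparse weighted sum. A minimal sketch of that dispatch, using hypothetical stand-in classes rather than the real feature-column types:

```python
# Stand-in classes for illustration only; not the TensorFlow implementation.
class DenseColumn(object):
  pass


class CategoricalColumn(object):
  pass


class BucketizedColumn(DenseColumn, CategoricalColumn):
  """A column that is both dense and categorical, like _BucketizedColumn."""


def dispatch(column):
  # Mirrors the reordered branch above: categorical wins for dual columns.
  if isinstance(column, CategoricalColumn):
    return 'categorical_weighted_sum'
  return 'dense_weighted_sum'


assert dispatch(BucketizedColumn()) == 'categorical_weighted_sum'
```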
@@ -237,7 +237,7 @@ def numeric_column(key,
   # or
   bucketized_price = bucketized_column(price, boundaries=[...])
   all_feature_columns = [bucketized_price, ...]
-  linear_prediction, _, _ = make_linear_model(features, all_feature_columns)
+  linear_prediction = make_linear_model(features, all_feature_columns)
   ```
@@ -291,6 +291,55 @@ def numeric_column(key,
       normalizer_fn=normalizer_fn)


+def bucketized_column(source_column, boundaries):
+  """Represents discretized dense input.
+
+  Buckets include the left boundary, and exclude the right boundary. Namely,
+  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
+  `[1., 2.)`, and `[2., +inf)`.
+
+  An example:
+
+  ```python
+  price = numeric_column('price')
+  bucketized_price = bucketized_column(price, boundaries=[...])
+  all_feature_columns = [bucketized_price, ...]
+  linear_prediction = make_linear_model(features, all_feature_columns)
+
+  # or
+  all_feature_columns = [bucketized_price, ...]
+  dense_tensor = make_input_layer(features, all_feature_columns)
+  ```
+
+  Args:
+    source_column: A one-dimensional dense column which is generated with
+      `numeric_column`.
+    boundaries: A sorted list or tuple of floats specifying the boundaries.
+
+  Returns:
+    A `_BucketizedColumn`.
+
+  Raises:
+    ValueError: If `source_column` is not a numeric column, or if it is not
+      one-dimensional.
+    ValueError: If `boundaries` is not a sorted list or tuple.
+  """
+  if not isinstance(source_column, _NumericColumn):
+    raise ValueError(
+        'source_column must be a column generated with numeric_column(). '
+        'Given: {}'.format(source_column))
+  if len(source_column.shape) > 1:
+    raise ValueError(
+        'source_column must be one-dimensional column. '
+        'Given: {}'.format(source_column))
+  if (not boundaries or
+      not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
+    raise ValueError('boundaries must be a sorted list.')
+  for i in range(len(boundaries) - 1):
+    if boundaries[i] >= boundaries[i + 1]:
+      raise ValueError('boundaries must be a sorted list.')
+  return _BucketizedColumn(source_column, tuple(boundaries))
+
+
 def categorical_column_with_hash_bucket(key,
                                         hash_bucket_size,
                                         dtype=dtypes.string):
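To make the boundary rule concrete, here is a plain-Python sketch of the assignment the docstring describes (the actual op used is `math_ops._bucketize`; this restatement is only illustrative):

```python
def bucketize(value, boundaries):
  """Returns the bucket index: left boundary inclusive, right exclusive."""
  for i, b in enumerate(boundaries):
    if value < b:
      return i
  return len(boundaries)


assert bucketize(-0.5, [0., 1., 2.]) == 0  # falls in (-inf, 0.)
assert bucketize(0.0, [0., 1., 2.]) == 1   # falls in [0., 1.)
assert bucketize(1.5, [0., 1., 2.]) == 2   # falls in [1., 2.)
assert bucketize(2.0, [0., 1., 2.]) == 3   # falls in [2., +inf)
```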
@@ -303,8 +352,8 @@ def categorical_column_with_hash_bucket(key,
   An example:

   ```python
   keywords = categorical_column_with_hash_bucket("keywords", 10K)
-  linear_prediction, _, _ = make_linear_model(features, all_feature_columns)
   all_feature_columns = [keywords, ...]
+  linear_prediction = make_linear_model(features, all_feature_columns)

   # or
   keywords_embedded = embedding_column(keywords, 16)
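For intuition, a hash-bucket column assigns each string an id by hashing it into a fixed number of buckets. A rough sketch of the idea, using Python's built-in `hash` as an assumed stand-in (TensorFlow uses its own stable fingerprint op, not `hash()`):

```python
hash_bucket_size = 10000  # the "10K" in the docstring example


def keyword_bucket(keyword):
  # Stand-in for the real fingerprint-based op; illustrative only.
  return hash(keyword) % hash_bucket_size


print(keyword_bucket('tensorflow'))  # some id in [0, 10000)
```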
@@ -668,6 +717,73 @@ class _NumericColumn(_DenseColumn,
     return inputs.get(self)


+class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
+                        collections.namedtuple('_BucketizedColumn', [
+                            'source_column', 'boundaries'])):
+  """See `bucketized_column`."""
+
+  @property
+  def name(self):
+    return '{}_bucketized'.format(self.source_column.name)
+
+  @property
+  def _parse_example_config(self):
+    return self.source_column._parse_example_config  # pylint: disable=protected-access
+
+  def _transform_feature(self, inputs):
+    source_tensor = inputs.get(self.source_column)
+    return math_ops._bucketize(  # pylint: disable=protected-access
+        source_tensor,
+        boundaries=self.boundaries)
+
+  @property
+  def _variable_shape(self):
+    return tuple(self.source_column.shape) + (len(self.boundaries) + 1,)
+
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+    del weight_collections
+    del trainable
+    input_tensor = inputs.get(self)
+    return array_ops.one_hot(
+        indices=math_ops.to_int64(input_tensor),
+        depth=len(self.boundaries) + 1,
+        on_value=1.,
+        off_value=0.)
+
+  @property
+  def _num_buckets(self):
+    # By construction, source_column is always one-dimensional.
+    return (len(self.boundaries) + 1) * self.source_column.shape[0]
+
+  def _get_sparse_tensors(self, inputs, weight_collections=None,
+                          trainable=None):
+    input_tensor = inputs.get(self)
+    batch_size = array_ops.shape(input_tensor)[0]
+    # By construction, source_column is always one-dimensional.
+    source_dimension = self.source_column.shape[0]
+    i1 = array_ops.reshape(
+        array_ops.tile(
+            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
+            [1, source_dimension]),
+        (-1,))
+    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
+    # Flatten the bucket indices and unique them across dimensions.
+    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets.
+    bucket_indices = (
+        array_ops.reshape(input_tensor, (-1,)) +
+        (len(self.boundaries) + 1) * i2)
+    indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
+    dense_shape = math_ops.to_int64(array_ops.stack(
+        [batch_size, source_dimension]))
+    sparse_tensor = sparse_tensor_lib.SparseTensor(
+        indices=indices,
+        values=bucket_indices,
+        dense_shape=dense_shape)
+    return _CategoricalColumn.IdWeightPair(sparse_tensor, None)
+
+
 def _create_tuple(shape, value):
   """Returns a tuple with given shape and filled with value."""
   if shape:
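The id arithmetic in `_get_sparse_tensors` above flattens per-dimension bucket indices into a single id space: with `k = len(boundaries) + 1` buckets, source dimension `j` contributes ids in `[j*k, (j+1)*k - 1]`. A plain-Python restatement of that arithmetic (illustrative, not the TF op graph):

```python
num_buckets = 4 + 1            # len(boundaries) + 1, for boundaries=[0, 2, 4, 6]
bucketized = [[0, 1], [3, 4]]  # bucket index per example, source shape=[2]

ids = []
for row in bucketized:
  for dim, bucket in enumerate(row):
    # Offset each dimension into its own id range of size num_buckets.
    ids.append(bucket + num_buckets * dim)

print(ids)  # [0, 6, 3, 9] -- the values asserted in the tests below
```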


@@ -151,7 +151,7 @@ class LazyColumnTest(test.TestCase):
       builder.get(NotAFeatureColumn())


-class NumericalColumnTest(test.TestCase):
+class NumericColumnTest(test.TestCase):

   def test_defaults(self):
     a = fc.numeric_column('aaa')
@@ -327,6 +327,231 @@ class NumericalColumnTest(test.TestCase):
       self.assertAllClose([[10.], [50.]], predictions.eval())


+class BucketizedColumnTest(test.TestCase):
+
+  def test_invalid_source_column_type(self):
+    a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
+    with self.assertRaisesRegexp(
+        ValueError,
+        'source_column must be a column generated with numeric_column'):
+      fc.bucketized_column(a, boundaries=[0, 1])
+
+  def test_invalid_source_column_shape(self):
+    a = fc.numeric_column('aaa', shape=[2, 3])
+    with self.assertRaisesRegexp(
+        ValueError, 'source_column must be one-dimensional column'):
+      fc.bucketized_column(a, boundaries=[0, 1])
+
+  def test_invalid_boundaries(self):
+    a = fc.numeric_column('aaa')
+    with self.assertRaisesRegexp(
+        ValueError, 'boundaries must be a sorted list'):
+      fc.bucketized_column(a, boundaries=None)
+    with self.assertRaisesRegexp(
+        ValueError, 'boundaries must be a sorted list'):
+      fc.bucketized_column(a, boundaries=1.)
+    with self.assertRaisesRegexp(
+        ValueError, 'boundaries must be a sorted list'):
+      fc.bucketized_column(a, boundaries=[1, 0])
+    with self.assertRaisesRegexp(
+        ValueError, 'boundaries must be a sorted list'):
+      fc.bucketized_column(a, boundaries=[1, 1])
+
+  def test_name(self):
+    a = fc.numeric_column('aaa', dtype=dtypes.int32)
+    b = fc.bucketized_column(a, boundaries=[0, 1])
+    self.assertEqual('aaa_bucketized', b.name)
+
+  def test_parse_config(self):
+    a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+    b = fc.bucketized_column(a, boundaries=[0, 1])
+    self.assertEqual({
+        'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
+    }, b._parse_example_config)
+
+  def test_variable_shape(self):
+    a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+    b = fc.bucketized_column(a, boundaries=[0, 1])
+    # Column 'aaa' has shape [2] times three buckets -> variable_shape=[2, 3].
+    self.assertAllEqual((2, 3), b._variable_shape)
+
+  def test_num_buckets(self):
+    a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+    b = fc.bucketized_column(a, boundaries=[0, 1])
+    # Column 'aaa' has shape [2] times three buckets -> num_buckets=6.
+    self.assertEqual(6, b._num_buckets)
+
+  def test_parse_example(self):
+    price = fc.numeric_column('price', shape=[2])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+    data = example_pb2.Example(features=feature_pb2.Features(
+        feature={
+            'price':
+                feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                    value=[20., 110.]))
+        }))
+    features = parsing_ops.parse_example(
+        serialized=[data.SerializeToString()],
+        features=bucketized_price._parse_example_config)
+    self.assertIn('price', features)
+    with self.test_session():
+      self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+  def test_transform_feature(self):
+    price = fc.numeric_column('price', shape=[2])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      builder = fc._LazyBuilder({
+          'price': constant_op.constant([[-1., 1.], [5., 6.]])
+      })
+      transformed_tensor = builder.get(bucketized_price)
+      with _initialized_session():
+        self.assertAllEqual([[0, 1], [3, 4]], transformed_tensor.eval())
+
+  def test_get_dense_tensor_one_input_value(self):
+    """Tests _get_dense_tensor() for input with shape=[1]."""
+    price = fc.numeric_column('price', shape=[1])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      builder = fc._LazyBuilder({
+          'price': constant_op.constant([[-1.], [1.], [5.], [6.]])
+      })
+      with _initialized_session():
+        bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
+        self.assertAllClose(
+            # One-hot tensor.
+            [[[1., 0., 0., 0., 0.]],
+             [[0., 1., 0., 0., 0.]],
+             [[0., 0., 0., 1., 0.]],
+             [[0., 0., 0., 0., 1.]]],
+            bucketized_price_tensor.eval())
+
+  def test_get_dense_tensor_two_input_values(self):
+    """Tests _get_dense_tensor() for input with shape=[2]."""
+    price = fc.numeric_column('price', shape=[2])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      builder = fc._LazyBuilder({
+          'price': constant_op.constant([[-1., 1.], [5., 6.]])
+      })
+      with _initialized_session():
+        bucketized_price_tensor = bucketized_price._get_dense_tensor(builder)
+        self.assertAllClose(
+            # One-hot tensor.
+            [[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
+             [[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
+            bucketized_price_tensor.eval())
+
+  def test_get_sparse_tensors_one_input_value(self):
+    """Tests _get_sparse_tensors() for input with shape=[1]."""
+    price = fc.numeric_column('price', shape=[1])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      builder = fc._LazyBuilder({
+          'price': constant_op.constant([[-1.], [1.], [5.], [6.]])
+      })
+      with _initialized_session() as sess:
+        id_weight_pair = bucketized_price._get_sparse_tensors(builder)
+        self.assertIsNone(id_weight_pair.weight_tensor)
+        id_tensor_value = sess.run(id_weight_pair.id_tensor)
+        self.assertAllEqual(
+            [[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices)
+        self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values)
+        self.assertAllEqual([4, 1], id_tensor_value.dense_shape)
+
+  def test_get_sparse_tensors_two_input_values(self):
+    """Tests _get_sparse_tensors() for input with shape=[2]."""
+    price = fc.numeric_column('price', shape=[2])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      builder = fc._LazyBuilder({
+          'price': constant_op.constant([[-1., 1.], [5., 6.]])
+      })
+      with _initialized_session() as sess:
+        id_weight_pair = bucketized_price._get_sparse_tensors(builder)
+        self.assertIsNone(id_weight_pair.weight_tensor)
+        id_tensor_value = sess.run(id_weight_pair.id_tensor)
+        self.assertAllEqual(
+            [[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices)
+        # Values 0-4 correspond to the first column of the input price.
+        # Values 5-9 correspond to the second column of the input price.
+        self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values)
+        self.assertAllEqual([2, 2], id_tensor_value.dense_shape)
+
+  def test_sparse_tensor_input_not_supported(self):
+    price = fc.numeric_column('price')
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
+    builder = fc._LazyBuilder({
+        'price':
+            sparse_tensor.SparseTensor(
+                indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+    })
+    with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+      bucketized_price._transform_feature(builder)
+
+  def test_deep_copy(self):
+    a = fc.numeric_column('aaa', shape=[2])
+    a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
+    a_bucketized_copy = copy.deepcopy(a_bucketized)
+    self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized')
+    self.assertAllEqual(a_bucketized_copy._variable_shape, (2, 3))
+    self.assertEqual(a_bucketized_copy.boundaries, (0, 1))
+
+  def test_make_linear_model_one_input_value(self):
+    """Tests make_linear_model() for input with shape=[1]."""
+    price = fc.numeric_column('price', shape=[1])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      features = {'price': constant_op.constant([[-1.], [1.], [5.], [6.]])}
+      predictions = fc.make_linear_model(features, [bucketized_price])
+      bias = get_linear_model_bias()
+      bucketized_price_var = get_linear_model_column_var(bucketized_price)
+      with _initialized_session() as sess:
+        self.assertAllClose([0.], bias.eval())
+        # One weight variable per bucket, all initialized to zero.
+        self.assertAllClose(
+            [[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval())
+        self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+        sess.run(bucketized_price_var.assign(
+            [[10.], [20.], [30.], [40.], [50.]]))
+        # price -1. is in the 0th bucket, whose weight is 10.
+        # price 1. is in the 1st bucket, whose weight is 20.
+        # price 5. is in the 3rd bucket, whose weight is 40.
+        # price 6. is in the 4th bucket, whose weight is 50.
+        self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+        sess.run(bias.assign([1.]))
+        self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+  def test_make_linear_model_two_input_values(self):
+    """Tests make_linear_model() for input with shape=[2]."""
+    price = fc.numeric_column('price', shape=[2])
+    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+    with ops.Graph().as_default():
+      features = {'price': constant_op.constant([[-1., 1.], [5., 6.]])}
+      predictions = fc.make_linear_model(features, [bucketized_price])
+      bias = get_linear_model_bias()
+      bucketized_price_var = get_linear_model_column_var(bucketized_price)
+      with _initialized_session() as sess:
+        self.assertAllClose([0.], bias.eval())
+        # One weight per bucket per input column, all initialized to zero.
+        self.assertAllClose(
+            [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+            bucketized_price_var.eval())
+        self.assertAllClose([[0.], [0.]], predictions.eval())
+        sess.run(bucketized_price_var.assign(
+            [[10.], [20.], [30.], [40.], [50.],
+             [60.], [70.], [80.], [90.], [100.]]))
+        # 1st example:
+        #   price -1. is in the 0th bucket, whose weight is 10.
+        #   price 1. is in the 6th bucket, whose weight is 70.
+        # 2nd example:
+        #   price 5. is in the 3rd bucket, whose weight is 40.
+        #   price 6. is in the 9th bucket, whose weight is 100.
+        self.assertAllClose([[80.], [140.]], predictions.eval())
+        sess.run(bias.assign([1.]))
+        self.assertAllClose([[81.], [141.]], predictions.eval())
+
+
 class SparseColumnHashedTest(test.TestCase):

   def test_defaults(self):
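The dense-tensor tests above expect a one-hot encoding with `len(boundaries) + 1 = 5` entries per source dimension. A short plain-Python check of that encoding (illustrative only, standing in for `array_ops.one_hot`):

```python
depth = 5  # len([0, 2, 4, 6]) + 1 buckets


def one_hot(index, depth):
  return [1. if i == index else 0. for i in range(depth)]


bucketized = [[0, 1], [3, 4]]  # bucketized 'price' with shape=[2]
dense = [[one_hot(v, depth) for v in row] for row in bucketized]
print(dense[0])  # [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]
print(dense[1])  # [[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
```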
@@ -567,6 +792,59 @@ class MakeLinearModelTest(test.TestCase):
       sess.run(price_var.assign([[10.]]))
       self.assertAllClose([[1015.], [10065.]], predictions.eval())

+  def test_dense_and_sparse_column(self):
+    """When the column is both dense and sparse, uses sparse tensors."""
+
+    class _DenseAndSparseColumn(fc._DenseColumn, fc._CategoricalColumn):
+
+      @property
+      def name(self):
+        return 'dense_and_sparse_column'
+
+      @property
+      def _parse_example_config(self):
+        return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+      def _transform_feature(self, inputs):
+        return inputs.get(self.name)
+
+      @property
+      def _variable_shape(self):
+        raise ValueError('Should not use this method.')
+
+      def _get_dense_tensor(self, inputs, weight_collections=None,
+                            trainable=None):
+        raise ValueError('Should not use this method.')
+
+      @property
+      def _num_buckets(self):
+        return 4
+
+      def _get_sparse_tensors(self, inputs, weight_collections=None,
+                              trainable=None):
+        sp_tensor = sparse_tensor.SparseTensor(
+            indices=[[0, 0], [1, 0], [1, 1]],
+            values=[2, 0, 3],
+            dense_shape=[2, 2])
+        return fc._CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+    dense_and_sparse_column = _DenseAndSparseColumn()
+    with ops.Graph().as_default():
+      sp_tensor = sparse_tensor.SparseTensor(
+          values=['omar', 'stringer', 'marlo'],
+          indices=[[0, 0], [1, 0], [1, 1]],
+          dense_shape=[2, 2])
+      features = {dense_and_sparse_column.name: sp_tensor}
+      predictions = fc.make_linear_model(features, [dense_and_sparse_column])
+      bias = get_linear_model_bias()
+      dense_and_sparse_column_var = get_linear_model_column_var(
+          dense_and_sparse_column)
+      with _initialized_session() as sess:
+        sess.run(dense_and_sparse_column_var.assign(
+            [[10.], [100.], [1000.], [10000.]]))
+        sess.run(bias.assign([5.]))
+        self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
   def test_dense_multi_output(self):
     price = fc.numeric_column('price')
     with ops.Graph().as_default():
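As a sanity check on the expected predictions in the new test: on the sparse path, the linear model sums one weight per id and then adds the bias. In plain Python (an illustrative restatement of the asserted numbers):

```python
weights = [10., 100., 1000., 10000.]  # one weight per bucket id
bias = 5.
examples = [[2], [0, 3]]              # sparse ids per example, from the test

predictions = [sum(weights[i] for i in ids) + bias for ids in examples]
print(predictions)  # [1005.0, 10015.0] -- matches the assertAllClose above
```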