Add categorical_column_with_vocabulary_file.

Move lookup_ops implementation from tensorflow/contrib/lookup to
tensorflow/python/feature_column.
Change: 155079825

parent 0374fd18c2
commit 42c7659edd
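For reference, a minimal usage sketch of the API this commit adds, assembled from the docstring examples below (TF 1.x graph mode; the vocabulary path and input values are placeholders, not part of this commit):

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

# One vocabulary entry per line; in-vocabulary values map to their line
# number (0-49), out-of-vocabulary values hash into IDs 50-54.
states = fc.categorical_column_with_vocabulary_file(
    key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
    num_oov_buckets=5)

features = {
    'states': tf.SparseTensor(
        indices=[[0, 0], [1, 0]], values=['CA', 'XX'], dense_shape=[2, 1]),
}
linear_prediction = fc.make_linear_model(features, [states])
```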
@@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator")
 add_python_module("tensorflow/python/estimator/export")
 add_python_module("tensorflow/python/estimator/inputs")
 add_python_module("tensorflow/python/estimator/inputs/queues")
+add_python_module("tensorflow/python/feature_column")
 add_python_module("tensorflow/python/framework")
 add_python_module("tensorflow/python/grappler")
 add_python_module("tensorflow/python/kernel_tests")
@@ -13,19 +13,10 @@ py_library(
     name = "lookup_py",
     srcs = [
         "__init__.py",
-        "lookup_ops.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:data_flow_ops_gen",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:util",
+        "//tensorflow/python/feature_column:lookup_ops",
     ],
 )
@@ -47,7 +47,7 @@ from __future__ import division
 from __future__ import print_function

 # pylint: disable=unused-import,wildcard-import
-from tensorflow.contrib.lookup.lookup_ops import *
+from tensorflow.python.feature_column.lookup_ops import *
 # pylint: enable=unused-import,wildcard-import

 from tensorflow.python.util.all_util import remove_undocumented
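The contrib package keeps the old import path alive by wildcard re-exporting the moved module. A hedged sketch of what this preserves (assuming `index_table_from_file` remains in the module's public exports):

```python
import tensorflow.contrib.lookup as contrib_lookup
from tensorflow.python.feature_column import lookup_ops

# After this commit, both import paths resolve to the same implementation.
assert contrib_lookup.index_table_from_file is lookup_ops.index_table_from_file
```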
@@ -82,6 +82,7 @@ py_library(
         "//third_party/py/numpy",
         "//tensorflow/python/estimator:estimator_py",
         "//tensorflow/python/feature_column:feature_column",
+        "//tensorflow/python/feature_column:lookup_ops",
         "//tensorflow/python/ops/losses",
         "//tensorflow/python/ops/distributions",
         "//tensorflow/python/saved_model",
@@ -1021,7 +1022,7 @@ tf_gen_op_wrapper_private_py(
     require_shape_functions = True,
     visibility = [
         "//learning/brain/python/ops:__pkg__",
-        "//tensorflow/contrib/lookup:__pkg__",
+        "//tensorflow/python/feature_column:__pkg__",
         "//tensorflow/python/kernel_tests:__pkg__",
     ],
 )
@@ -29,6 +29,7 @@ py_library(
     srcs = ["feature_column.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":lookup_ops",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:framework",
         "//tensorflow/python:init_ops",
@@ -44,14 +45,47 @@ py_library(
     ],
 )

+filegroup(
+    name = "vocabulary_testdata",
+    srcs = [
+        "testdata/warriors_vocabulary.txt",
+        "testdata/wire_vocabulary.txt",
+    ],
+)
+
 py_test(
     name = "feature_column_test",
     srcs = ["feature_column_test.py"],
+    data = [":vocabulary_testdata"],
     srcs_version = "PY2AND3",
+    tags = ["no_pip"],
     deps = [
         ":feature_column",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
     ],
 )
+
+# TODO(ptucker,yleon): Move along with 3p/tf/contrib/lookup.
+# Test is still in 3p/tf/contrib/lookup.
+py_library(
+    name = "lookup_ops",
+    srcs = [
+        "lookup_ops.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:data_flow_ops_gen",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:util",
+    ],
+)
@@ -121,6 +121,7 @@ from __future__ import print_function
 import abc
 import collections

+from tensorflow.python.feature_column import lookup_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -331,7 +332,9 @@ def numeric_column(key,
   ```

   Args:
-    key: A string providing key to look up corresponding `Tensor`.
+    key: A unique string identifying the input feature. It is used as the
+      column name and the dictionary key for feature parsing configs, feature
+      `Tensor` objects, and feature columns.
     shape: An iterable of integers specifies the shape of the `Tensor`. An
       integer can be given which means a single dimension `Tensor` with given
       width. The `Tensor` representing the column will have the shape of
@@ -443,22 +446,22 @@ def categorical_column_with_hash_bucket(key,

   ```python
   keywords = categorical_column_with_hash_bucket("keywords", 10K)
-  all_feature_columns = [keywords, ...]
-  linear_prediction = make_linear_model(features, all_feature_columns)
+  linear_prediction = make_linear_model(features, [keywords, ...])

   # or
   keywords_embedded = embedding_column(keywords, 16)
-  all_feature_columns = [keywords_embedded, ...]
-  dense_tensor = make_input_layer(features, all_feature_columns)
+  dense_tensor = make_input_layer(features, [keywords_embedded, ...])
   ```

   Args:
-    key: A string providing key to look up corresponding `Tensor`.
+    key: A unique string identifying the input feature. It is used as the
+      column name and the dictionary key for feature parsing configs, feature
+      `Tensor` objects, and feature columns.
     hash_bucket_size: An int > 1. The number of buckets.
     dtype: The type of features. Only string and integer types are supported.

   Returns:
-    A `_CategoricalColumnHashed`.
+    A `_HashedCategoricalColumn`.

   Raises:
     ValueError: `hash_bucket_size` is not greater than 1.
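The hashed-column test expectations later in this commit ('marlo' -> 3, 'skywalker' -> 2, 'omar' -> 2 with 4 buckets) come from hashing each string into `hash_bucket_size` buckets. A hedged sketch, assuming the column is backed by `tf.string_to_hash_bucket_fast` (the implementation is not shown in this hunk):

```python
import tensorflow as tf

# Hash each string into one of 4 buckets, as the hashed column does.
ids = tf.string_to_hash_bucket_fast(
    ['marlo', 'skywalker', 'omar'], num_buckets=4)
with tf.Session() as sess:
  print(sess.run(ids))  # expected per the tests below: [3 2 2]
```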
@@ -476,7 +479,100 @@ def categorical_column_with_hash_bucket(key,
     raise ValueError('dtype must be string or integer. '
                      'dtype: {}, column_name: {}'.format(dtype, key))

-  return _CategoricalColumnHashed(key, hash_bucket_size, dtype)
+  return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
+
+
+def categorical_column_with_vocabulary_file(
+    key, vocabulary_file, vocabulary_size, num_oov_buckets=0,
+    default_value=None, dtype=dtypes.string):
+  """Creates a `_CategoricalColumn` with vocabulary file configuration.
+
+  Use this when your inputs are in string or integer format, and you have a
+  vocabulary file that maps each value to an integer ID. By default,
+  out-of-vocabulary values are ignored. Use either (but not both) of
+  `num_oov_buckets` and `default_value` to specify how to include
+  out-of-vocabulary values.
+
+  Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values
+  can be represented by `-1` for int and `''` for string. Note that these
+  values are independent of the `default_value` argument.
+
+  Example with `num_oov_buckets`:
+  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
+  abbreviation. All inputs with values in that file are assigned an ID 0-49,
+  corresponding to its line number. All other values are hashed and assigned an
+  ID 50-54.
+  ```python
+  states = categorical_column_with_vocabulary_file(
+      key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=50,
+      num_oov_buckets=5)
+  linear_prediction = make_linear_model(features, [states, ...])
+  ```
+
+  Example with `default_value`:
+  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
+  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
+  in input, and other values missing from the file, will be assigned ID 0. All
+  others are assigned the corresponding line number 1-50.
+  ```python
+  states = categorical_column_with_vocabulary_file(
+      key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=51,
+      default_value=0)
+  linear_prediction, _, _ = make_linear_model(features, [states, ...])
+  ```
+
+  And to make an embedding with either:
+  ```python
+  dense_tensor = make_input_layer(features, [embedding_column(states, 3),...])
+  ```
+
+  Args:
+    key: A unique string identifying the input feature. It is used as the
+      column name and the dictionary key for feature parsing configs, feature
+      `Tensor` objects, and feature columns.
+    vocabulary_file: The vocabulary file name.
+    vocabulary_size: Number of the elements in the vocabulary.
+    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+      the input value. A positive `num_oov_buckets` can not be specified with
+      `default_value`.
+    default_value: The integer ID value to return for out-of-vocabulary feature
+      values, defaults to -1. This can not be specified with a positive
+      `num_oov_buckets`.
+    dtype: The type of features. Only string and integer types are supported.
+
+  Returns:
+    A `_CategoricalColumn` with vocabulary file configuration.
+
+  Raises:
+    ValueError: `vocabulary_file` is missing.
+    ValueError: `vocabulary_size` is missing or < 1.
+    ValueError: `num_oov_buckets` is not a non-negative integer.
+    ValueError: `dtype` is neither string nor integer.
+  """
+  if not vocabulary_file:
+    raise ValueError('Missing vocabulary_file in {}.'.format(key))
+  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
+  # TODO(ptucker): Should we fail for vocabulary_size==1?
+  if (vocabulary_size is None) or (vocabulary_size < 1):
+    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
+  if num_oov_buckets:
+    if default_value is not None:
+      raise ValueError(
+          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+              key))
+    if num_oov_buckets < 0:
+      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+          num_oov_buckets, key))
+  if dtype != dtypes.string and not dtype.is_integer:
+    raise ValueError('Invalid dtype {} in {}.'.format(dtype, key))
+  return _VocabularyCategoricalColumn(
+      key=key,
+      vocabulary_file=vocabulary_file,
+      vocabulary_size=vocabulary_size,
+      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
+      default_value=-1 if default_value is None else default_value,
+      dtype=dtype)


 class _FeatureColumn(object):
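A pure-Python model of the ID assignment the docstring above describes; `hash_fn` is a stand-in for the table's hasher, so only the bucket range (not the exact bucket) matches the real op:

```python
def assign_id(value, vocabulary, num_oov_buckets=0, default_value=-1,
              hash_fn=hash):
  """In-vocabulary values get their line number; OOV values either hash
  into [len(vocabulary), len(vocabulary) + num_oov_buckets) or fall back
  to default_value."""
  if value in vocabulary:
    return vocabulary.index(value)
  if num_oov_buckets:
    return len(vocabulary) + hash_fn(value) % num_oov_buckets
  return default_value

states = ['CA', 'WA', 'OR']
assert assign_id('WA', states) == 1                           # line number
assert assign_id('XX', states) == -1                          # ignored
assert assign_id('XX', states, num_oov_buckets=5) in range(3, 8)
assert assign_id('XX', states, default_value=0) == 0
```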
@@ -764,6 +860,67 @@ class _LazyBuilder(object):
     return transformed


+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _shape_offsets(shape):
+  """Returns moving offset for each dimension given shape."""
+  offsets = []
+  for dim in reversed(shape):
+    if offsets:
+      offsets.append(dim * offsets[-1])
+    else:
+      offsets.append(dim)
+  offsets.reverse()
+  return offsets
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _to_sparse_input(input_tensor, ignore_value=None):
+  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
+
+  If `input_tensor` is already a `SparseTensor`, just return it.
+
+  Args:
+    input_tensor: A string or integer `Tensor`.
+    ignore_value: Entries in `dense_tensor` equal to this value will be
+      absent from the resulting `SparseTensor`. If `None`, default value of
+      `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
+
+  Returns:
+    A `SparseTensor` with the same shape as `input_tensor`.
+
+  Raises:
+    ValueError: when `input_tensor`'s rank is `None`.
+  """
+  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+      input_tensor)
+  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+    return input_tensor
+  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
+    input_rank = input_tensor.get_shape().ndims
+    if input_rank is None:
+      # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank.
+      raise ValueError('Undefined input_tensor shape.')
+    if ignore_value is None:
+      ignore_value = '' if input_tensor.dtype == dtypes.string else -1
+    dense_shape = math_ops.cast(array_ops.shape(input_tensor), dtypes.int64)
+    indices = array_ops.where(math_ops.not_equal(
+        input_tensor, math_ops.cast(ignore_value, input_tensor.dtype)))
+    # Flattens the tensor and indices for use with gather.
+    flat_tensor = array_ops.reshape(input_tensor, [-1])
+    flat_indices = indices[:, input_rank - 1]
+    # Computes the correct flattened indices for 2d (or higher) tensors.
+    if input_rank > 1:
+      higher_dims = indices[:, :input_rank - 1]
+      shape_offsets = array_ops.stack(
+          _shape_offsets(array_ops.unstack(dense_shape)[1:]))
+      offsets = math_ops.reduce_sum(
+          math_ops.multiply(higher_dims, shape_offsets),
+          reduction_indices=[1])
+      flat_indices = math_ops.add(flat_indices, offsets)
+    values = array_ops.gather(flat_tensor, flat_indices)
+    return sparse_tensor_lib.SparseTensor(indices, values, dense_shape)
+
+
 def _check_feature_columns(feature_columns):
   if isinstance(feature_columns, dict):
     raise ValueError('Expected feature_columns to be iterable, found dict.')
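To see the row-major flattening that `_shape_offsets` and the gather in `_to_sparse_input` perform, here is a NumPy model using the same `(('omar', ''), ('stringer', 'marlo'))` input the tests use; only the index arithmetic is modeled, not the TF ops:

```python
import numpy as np

def shape_offsets(shape):
  # offsets[i] = prod(shape[i:]): the row-major stride of dimension i,
  # mirroring _shape_offsets above.
  offsets = []
  for dim in reversed(shape):
    offsets.append(dim * offsets[-1] if offsets else dim)
  offsets.reverse()
  return offsets

dense = np.array([['omar', ''], ['stringer', 'marlo']])
keep = np.argwhere(dense != '')                  # like array_ops.where
strides = shape_offsets(list(dense.shape[1:]))   # [2]
flat = keep[:, -1] + keep[:, :-1].dot(strides)   # flattened gather indices
print(keep.tolist())                     # [[0, 0], [1, 0], [1, 1]]
print(flat.tolist())                     # [0, 2, 3]
print(dense.reshape(-1)[flat].tolist())  # ['omar', 'stringer', 'marlo']
```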
@@ -951,7 +1108,7 @@ def _check_default_value(shape, default_value, dtype, key):
       `shape`.
     dtype: defines the type of values. Default value is `tf.float32`. Must be a
       non-quantized, real integer or floating point type.
-    key: A string providing key to look up corresponding `Tensor`.
+    key: Column name, used only for error messages.

   Returns:
     A tuple which will be used as default value.
@@ -994,9 +1151,9 @@ def _check_default_value(shape, default_value, dtype, key):
         default_value, dtype, key))


-class _CategoricalColumnHashed(
+class _HashedCategoricalColumn(
     _CategoricalColumn,
-    collections.namedtuple('_CategoricalColumnHashed',
+    collections.namedtuple('_HashedCategoricalColumn',
                            ['key', 'hash_bucket_size', 'dtype'])):
   """see `categorical_column_with_hash_bucket`."""

@@ -1009,7 +1166,7 @@ class _CategoricalColumnHashed(
     return {self.key: parsing_ops.VarLenFeature(self.dtype)}

   def _transform_feature(self, inputs):
-    input_tensor = inputs.get(self.key)
+    input_tensor = _to_sparse_input(inputs.get(self.key))
     if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
       raise ValueError('SparseColumn input must be a SparseTensor.')

@@ -1045,6 +1202,58 @@ class _CategoricalColumnHashed(
     return _CategoricalColumn.IdWeightPair(inputs.get(self), None)


+class _VocabularyCategoricalColumn(
+    _CategoricalColumn, collections.namedtuple('_VocabularyCategoricalColumn', (
+        'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype',
+        'default_value'
+    ))):
+  """See `categorical_column_with_vocabulary_file`."""
+
+  @property
+  def name(self):
+    return self.key
+
+  @property
+  def _parse_example_config(self):
+    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+  def _transform_feature(self, inputs):
+    input_tensor = _to_sparse_input(inputs.get(self.key))
+
+    if self.dtype.is_integer != input_tensor.dtype.is_integer:
+      raise ValueError(
+          'Column dtype and SparseTensors dtype must be compatible. '
+          'key: {}, column dtype: {}, tensor dtype: {}'.format(
+              self.key, self.dtype, input_tensor.dtype))
+
+    key_dtype = self.dtype
+    if input_tensor.dtype.is_integer:
+      # `index_table_from_file` requires 64-bit integer keys.
+      key_dtype = dtypes.int64
+      input_tensor = math_ops.to_int64(input_tensor)
+    elif input_tensor.dtype != dtypes.string:
+      raise ValueError('input tensors dtype must be string or integer. '
+                       'dtype: {}, column_name: {}'.format(
+                           input_tensor.dtype, self.key))
+
+    return lookup_ops.index_table_from_file(
+        vocabulary_file=self.vocabulary_file,
+        num_oov_buckets=self.num_oov_buckets,
+        vocab_size=self.vocabulary_size,
+        default_value=self.default_value,
+        key_dtype=key_dtype,
+        name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+  @property
+  def _num_buckets(self):
+    """Returns number of buckets in this sparse feature."""
+    return self.vocabulary_size + self.num_oov_buckets
+
+  def _get_sparse_tensors(
+      self, inputs, weight_collections=None, trainable=None):
+    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
 # TODO(zakaria): Move this to embedding_ops and make it public.
 def _safe_embedding_lookup_sparse(embedding_weights,
                                   sparse_ids,
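A minimal graph-mode sketch of the lookup `_transform_feature` builds, using the `index_table_from_file` call shown above with the wire vocabulary from this commit's testdata (the file path is assumed relative to the working directory; TF 1.x session API):

```python
import tensorflow as tf
from tensorflow.python.feature_column import lookup_ops

table = lookup_ops.index_table_from_file(
    vocabulary_file='testdata/wire_vocabulary.txt',  # omar, stringer, marlo
    num_oov_buckets=0, vocab_size=3, default_value=-1,
    key_dtype=tf.string, name='wire_lookup')
ids = table.lookup(tf.constant(['marlo', 'skywalker', 'omar']))

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))  # [ 2 -1  0], as asserted in the tests below
```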
@@ -28,6 +28,7 @@ from tensorflow.python.client import session
 from tensorflow.python.feature_column import feature_column as fc
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import data_flow_ops
@@ -552,7 +553,7 @@ class BucketizedColumnTest(test.TestCase):
     self.assertAllClose([[81.], [141.]], predictions.eval())


-class SparseColumnHashedTest(test.TestCase):
+class HashedCategoricalColumnTest(test.TestCase):

   def test_defaults(self):
     a = fc.categorical_column_with_hash_bucket('aaa', 10)
@@ -578,11 +579,14 @@ class SparseColumnHashedTest(test.TestCase):

   def test_deep_copy(self):
     """Tests deepcopy of categorical_column_with_hash_bucket."""
-    column = fc.categorical_column_with_hash_bucket('aaa', 10)
-    column_copy = copy.deepcopy(column)
-    self.assertEqual('aaa', column_copy.name)
-    self.assertEqual(10, column_copy.hash_bucket_size)
-    self.assertEqual(dtypes.string, column_copy.dtype)
+    original = fc.categorical_column_with_hash_bucket('aaa', 10)
+    for column in (original, copy.deepcopy(original)):
+      self.assertEqual('aaa', column.name)
+      self.assertEqual(10, column.hash_bucket_size)
+      # pylint: disable=protected-access
+      self.assertEqual(10, column._num_buckets)
+      # pylint: enable=protected-access
+      self.assertEqual(dtypes.string, column.dtype)

   def test_parse_config(self):
     a = fc.categorical_column_with_hash_bucket('aaa', 10)
@@ -681,14 +685,45 @@ class SparseColumnHashedTest(test.TestCase):

   def test_get_sparse_tensors(self):
     hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
-    wire_tensor = sparse_tensor.SparseTensor(
-        values=['omar', 'stringer', 'marlo'],
-        indices=[[0, 0], [1, 0], [1, 1]],
-        dense_shape=[2, 2])
-    builder = fc._LazyBuilder({'wire': wire_tensor})
-    self.assertEqual(
-        builder.get(hashed_sparse),
-        hashed_sparse._get_sparse_tensors(builder).id_tensor)
+    builder = fc._LazyBuilder({
+        'wire': sparse_tensor.SparseTensor(
+            values=['omar', 'stringer', 'marlo'],
+            indices=[[0, 0], [1, 0], [1, 1]],
+            dense_shape=[2, 2])
+    })
+    id_weight_pair = hashed_sparse._get_sparse_tensors(builder)
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
+
+  def test_get_sparse_tensors_dense_input(self):
+    hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+    builder = fc._LazyBuilder({
+        'wire': (('omar', ''), ('stringer', 'marlo'))
+    })
+    id_weight_pair = hashed_sparse._get_sparse_tensors(builder)
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
+
+  def test_make_linear_model(self):
+    wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+    self.assertEqual(4, wire_column._num_buckets)
+    with ops.Graph().as_default():
+      predictions = fc.make_linear_model({
+          wire_column.name: sparse_tensor.SparseTensorValue(
+              indices=((0, 0), (1, 0), (1, 1)),
+              values=('marlo', 'skywalker', 'omar'),
+              dense_shape=(2, 2))
+      }, (wire_column,))
+      bias = get_linear_model_bias()
+      wire_var = get_linear_model_column_var(wire_column)
+      with _initialized_session():
+        self.assertAllClose((0.,), bias.eval())
+        self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+        self.assertAllClose(((0.,), (0.,)), predictions.eval())
+        wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+        # 'marlo' -> 3: wire_var[3] = 4
+        # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+        self.assertAllClose(((4.,), (6.,)), predictions.eval())


 def get_linear_model_bias():
@@ -1158,5 +1193,350 @@ class MakeInputLayerTest(test.TestCase):
     self.assertAllClose([[1., 3.]], net2.eval())


+class VocabularyCategoricalColumnTest(test.TestCase):
+
+  def setUp(self):
+    super(VocabularyCategoricalColumnTest, self).setUp()
+
+    # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
+    self._warriors_vocabulary_file_name = test.test_src_dir_path(
+        'python/feature_column/testdata/warriors_vocabulary.txt')
+    self._warriors_vocabulary_size = 5
+
+    # Contains strings, character names from 'The Wire': omar, stringer, marlo
+    self._wire_vocabulary_file_name = test.test_src_dir_path(
+        'python/feature_column/testdata/wire_vocabulary.txt')
+    self._wire_vocabulary_size = 3
+
+  def _assert_sparse_tensor_value(self, expected, actual):
+    self.assertEqual(np.int64, np.array(actual.indices).dtype)
+    self.assertAllEqual(expected.indices, actual.indices)
+
+    self.assertEqual(
+        np.array(expected.values).dtype, np.array(actual.values).dtype)
+    self.assertAllEqual(expected.values, actual.values)
+
+    self.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+    self.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+  def test_defaults(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+    self.assertEqual('aaa', column.name)
+    # pylint: disable=protected-access
+    self.assertEqual(3, column._num_buckets)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.string)
+    }, column._parse_example_config)
+    # pylint: enable=protected-access
+
+  def test_all_constructor_args(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+        num_oov_buckets=4, dtype=dtypes.int32)
+    # pylint: disable=protected-access
+    self.assertEqual(7, column._num_buckets)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+    }, column._parse_example_config)
+    # pylint: enable=protected-access
+
+  def test_deep_copy(self):
+    """Tests deepcopy of categorical_column_with_vocabulary_file."""
+    original = fc.categorical_column_with_vocabulary_file(
+        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+        num_oov_buckets=4, dtype=dtypes.int32)
+    for column in (original, copy.deepcopy(original)):
+      self.assertEqual('aaa', column.name)
+      # pylint: disable=protected-access
+      self.assertEqual(7, column._num_buckets)
+      self.assertEqual({
+          'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+      }, column._parse_example_config)
+      # pylint: enable=protected-access
+
+  def test_vocabulary_file_none(self):
+    with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file=None, vocabulary_size=3)
+
+  def test_vocabulary_file_empty_string(self):
+    with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file='', vocabulary_size=3)
+
+  def test_invalid_vocabulary_file(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    # pylint: disable=protected-access
+    column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
+      with self.test_session():
+        data_flow_ops.tables_initializer().run()
+
+  def test_invalid_vocabulary_size(self):
+    with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+          vocabulary_size=None)
+    with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+          vocabulary_size=-1)
+    with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+          vocabulary_size=0)
+
+  def test_too_large_vocabulary_size(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size + 1)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    # pylint: disable=protected-access
+    column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
+      with self.test_session():
+        data_flow_ops.tables_initializer().run()
+
+  def test_invalid_num_oov_buckets(self):
+    with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file='path', vocabulary_size=3,
+          num_oov_buckets=-1)
+
+  def test_invalid_dtype(self):
+    with self.assertRaisesRegexp(ValueError, 'Invalid dtype'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa', vocabulary_file='path', vocabulary_size=3,
+          dtype=dtypes.float64)
+
+  def test_invalid_buckets_and_default_value(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'both num_oov_buckets and default_value'):
+      fc.categorical_column_with_vocabulary_file(
+          key='aaa',
+          vocabulary_file=self._wire_vocabulary_file_name,
+          vocabulary_size=self._wire_vocabulary_size,
+          num_oov_buckets=100,
+          default_value=2)
+
+  def test_get_sparse_tensors(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((2, -1, 0), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_dense_input(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+        'aaa': (('marlo', ''), ('skywalker', 'omar'))
+    }))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=((0, 0), (1, 0), (1, 1)),
+              values=np.array((2, -1, 0), dtype=np.int64),
+              dense_shape=(2, 2)),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_default_value_in_vocabulary(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size,
+        default_value=2)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((2, 2, 0), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_with_oov_buckets(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size,
+        num_oov_buckets=100)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+        values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+        dense_shape=(2, 3))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((2, 33, 0, 62), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_small_vocabulary_size(self):
+    # 'marlo' is the last entry in our vocabulary file, so by setting
+    # `vocabulary_size` to 1 less than number of entries in file, we take
+    # 'marlo' out of the vocabulary.
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size - 1)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=('marlo', 'skywalker', 'omar'),
+        dense_shape=(2, 2))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((-1, -1, 0), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_int32(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._warriors_vocabulary_file_name,
+        vocabulary_size=self._warriors_vocabulary_size,
+        dtype=dtypes.int32)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+        values=(11, 100, 30, 22),
+        dense_shape=(3, 3))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((2, -1, 0, 4), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_int32_dense_input(self):
+    default_value = -100
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._warriors_vocabulary_file_name,
+        vocabulary_size=self._warriors_vocabulary_size,
+        dtype=dtypes.int32,
+        default_value=default_value)
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+        'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
+    }))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+              values=np.array((2, default_value, 0, 4), dtype=np.int64),
+              dense_shape=(3, 3)),
+          id_weight_pair.id_tensor.eval())
+
+  def test_get_sparse_tensors_int32_with_oov_buckets(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._warriors_vocabulary_file_name,
+        vocabulary_size=self._warriors_vocabulary_size,
+        dtype=dtypes.int32,
+        num_oov_buckets=100)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+        values=(11, 100, 30, 22),
+        dense_shape=(3, 3))
+    # pylint: disable=protected-access
+    id_weight_pair = column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}))
+    # pylint: enable=protected-access
+    self.assertIsNone(id_weight_pair.weight_tensor)
+    with _initialized_session():
+      self._assert_sparse_tensor_value(
+          sparse_tensor.SparseTensorValue(
+              indices=inputs.indices,
+              values=np.array((2, 60, 0, 4), dtype=np.int64),
+              dense_shape=inputs.dense_shape),
+          id_weight_pair.id_tensor.eval())
+
+  def test_make_linear_model(self):
+    wire_column = fc.categorical_column_with_vocabulary_file(
+        key='wire',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size,
+        num_oov_buckets=1)
+    self.assertEqual(4, wire_column._num_buckets)
+    with ops.Graph().as_default():
+      predictions = fc.make_linear_model({
+          wire_column.name: sparse_tensor.SparseTensorValue(
+              indices=((0, 0), (1, 0), (1, 1)),
+              values=('marlo', 'skywalker', 'omar'),
+              dense_shape=(2, 2))
+      }, (wire_column,))
+      bias = get_linear_model_bias()
+      wire_var = get_linear_model_column_var(wire_column)
+      with _initialized_session():
+        self.assertAllClose((0.,), bias.eval())
+        self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+        self.assertAllClose(((0.,), (0.,)), predictions.eval())
+        wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+        # 'marlo' -> 2: wire_var[2] = 3
+        # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+        self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
 if __name__ == '__main__':
   test.main()
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Lookup table Operations."""
-# pylint: disable=g-bad-name
+"""Lookup table operations."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -608,7 +608,7 @@ class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
   __slots__ = ()


-FastHashSpec = HasherSpec("fasthash", None)
+FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


 class StrongHashSpec(HasherSpec):
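`FastHashSpec` is the default hasher for out-of-vocabulary buckets; `StrongHashSpec` swaps in a keyed hash for inputs where a stronger guarantee is wanted. A hedged sketch of selecting a hasher (assuming the `hasher_spec` argument of `index_table_from_file`, which this module carries over from contrib; the vocabulary path is a placeholder):

```python
from tensorflow.python.feature_column import lookup_ops

# Default: fast, non-cryptographic fingerprinting of OOV values.
table = lookup_ops.index_table_from_file(
    vocabulary_file='vocab.txt', num_oov_buckets=5)

# Keyed strong hash: the two 64-bit values form the hash key.
strong_table = lookup_ops.index_table_from_file(
    vocabulary_file='vocab.txt', num_oov_buckets=5,
    hasher_spec=lookup_ops.StrongHashSpec((1, 2)))
```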
tensorflow/python/feature_column/testdata/warriors_vocabulary.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
+30
+35
+11
+23
+22
tensorflow/python/feature_column/testdata/wire_vocabulary.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
+omar
+stringer
+marlo
@@ -45,6 +45,7 @@ BLACKLIST = [
     "//tensorflow/python:compare_test_proto_py",
     "//tensorflow/core:image_testdata",
     "//tensorflow/core/kernels/cloud:bigquery_reader_ops",
+    "//tensorflow/python/feature_column:vocabulary_testdata",
     "//tensorflow/python:framework/test_file_system.so",
     # contrib
    "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",