Add embedding_column.

Add some pydoc to other columns. Add weight_collections tests for all columns.
Rename embedding 'weights' -> 'embedding_weights', linear 'weight' -> 'weights',
and linear 'bias_weight' -> 'bias_weights'.

PiperOrigin-RevId: 155648428

commit 64698dd7bb (parent 8bda9d06f8)
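For anyone fetching these variables by name, the renames are the visible API change. A minimal sketch, not part of this commit, assuming the TF 1.x `variable_scope` API and the `make_linear_model` scope used in the tests below:

```python
import tensorflow as tf

def get_renamed_linear_bias():
  # 'bias_weights' replaces the old 'bias_weight'; a column's linear weights
  # are now 'weights' (was 'weight'), and an embedding column's table is
  # 'embedding_weights' (was 'weights').
  with tf.variable_scope('make_linear_model', reuse=True):
    return tf.get_variable('bias_weights')
```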
tensorflow/python/feature_column
@@ -39,6 +39,7 @@ py_library(
         "//tensorflow/python:sparse_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
+        "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
@@ -48,6 +49,9 @@ py_library(
 filegroup(
     name = "vocabulary_testdata",
     srcs = [
+        "testdata/embedding.ckpt.data-00000-of-00001",
+        "testdata/embedding.ckpt.index",
+        "testdata/embedding.ckpt.meta",
         "testdata/warriors_vocabulary.txt",
         "testdata/wire_vocabulary.txt",
     ],
@@ -123,6 +123,7 @@ from __future__ import print_function

 import abc
 import collections
+import math

 import numpy as np
 import six
@@ -144,6 +145,7 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.util import nest
@@ -304,7 +306,7 @@ def make_linear_model(features,
   predictions_no_bias = math_ops.add_n(
       weigthed_sums, name='weighted_sum_no_bias')
   bias = variable_scope.get_variable(
-      'bias_weight',
+      'bias_weights',
       shape=[units],
       initializer=init_ops.zeros_initializer(),
       trainable=trainable,
@@ -416,6 +418,88 @@ def make_parse_example_spec(feature_columns):
   return result


+def embedding_column(
+    categorical_column, dimension, combiner='mean', initializer=None,
+    ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
+    trainable=True):
+  """`_DenseColumn` that converts from sparse, categorical input.
+
+  Use this when your inputs are sparse, but you want to convert them to a dense
+  representation (e.g., to feed to a DNN).
+
+  Inputs must be a `SparseTensor` by way of the provided
+  `categorical_column._get_sparse_tensors`.
+
+  Any of the `categorical_column_*` columns can be provided as input. Here is an
+  example embedding of an identity column for a DNN model:
+
+  ```python
+  video_id = categorical_column_with_identity(
+      key='video_id', num_buckets=1000000, default_value=0)
+  columns = [embedding_column(video_id, 9),...]
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  dense_tensor = make_input_layer(features, columns)
+  ```
+
+  Args:
+    categorical_column: A `_CategoricalColumn` created by a
+      `categorical_column_with_*` function. This column produces the sparse IDs
+      that are inputs to the embedding lookup.
+    dimension: An integer specifying the dimension of the embedding; must
+      be > 0.
+    combiner: A string specifying how to reduce if there are multiple entries
+      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+      with bag-of-words columns. Each of these can be thought of as an
+      example-level normalization on the column. For more information, see
+      `tf.embedding_lookup_sparse`.
+    initializer: A variable initializer function to be used in embedding
+      variable initialization. If not specified, defaults to
+      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+      `1/sqrt(dimension)`.
+    ckpt_to_load_from: String representing checkpoint name/pattern from which
+      to restore column weights. Required if `tensor_name_in_ckpt` is not
+      `None`.
+    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+      which to restore the column weights. Required if `ckpt_to_load_from` is
+      not `None`.
+    max_norm: If not `None`, embedding values are l2-normalized to this value.
+    trainable: Whether or not the embedding is trainable. Default is `True`.
+
+  Returns:
+    `_DenseColumn` that converts from sparse input.
+
+  Raises:
+    ValueError: if `dimension` is not > 0.
+    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+      is specified.
+    ValueError: if `initializer` is specified and is not callable.
+  """
+  if (dimension is None) or (dimension < 1):
+    raise ValueError('Invalid dimension {}.'.format(dimension))
+  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+    raise ValueError('Must specify both `ckpt_to_load_from` and '
+                     '`tensor_name_in_ckpt` or none of them.')
+
+  if (initializer is not None) and (not callable(initializer)):
+    raise ValueError('initializer must be callable if specified. '
+                     'Embedding of column_name: {}'.format(
+                         categorical_column.name))
+  if initializer is None:
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1 / math.sqrt(dimension))
+
+  return _EmbeddingColumn(
+      categorical_column=categorical_column,
+      dimension=dimension,
+      combiner=combiner,
+      initializer=initializer,
+      ckpt_to_load_from=ckpt_to_load_from,
+      tensor_name_in_ckpt=tensor_name_in_ckpt,
+      max_norm=max_norm,
+      trainable=trainable)
+
+
 def numeric_column(key,
                    shape=(1,),
                    default_value=None,
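Putting the new function together, a minimal usage sketch (not from the diff; it assumes the module is importable as `fc`, the way the tests below import it):

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

# Embed a 3-word vocabulary column into 4 dimensions with the 'sqrtn' combiner.
colors = fc.categorical_column_with_vocabulary_list(
    key='colors', vocabulary_list=('red', 'green', 'blue'))
colors_embedded = fc.embedding_column(colors, dimension=4, combiner='sqrtn')
features = {'colors': tf.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1]],
    values=['red', 'blue', 'blue'],
    dense_shape=[2, 2])}
dense_tensor = fc.make_input_layer(features, [colors_embedded])  # shape (2, 4)
```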
@@ -568,7 +652,7 @@ def categorical_column_with_hash_bucket(key,
   want to distribute your inputs into a finite number of buckets by hashing.
   output_id = Hash(input_feature_string) % bucket_size

-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that these
   values are independent of the `default_value` argument.
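The modulo rule above is easy to mimic outside TensorFlow. A plain-Python sketch (the real op uses a specific fingerprint hash, so only the structure matches; Python's built-in `hash` is also randomized per process for strings):

```python
def hash_bucket(input_feature_string, bucket_size):
  # Illustrative stand-in for: output_id = Hash(input_feature_string) % bucket_size
  return hash(input_feature_string) % bucket_size

bucket = hash_bucket('marlo', 10)
assert 0 <= bucket < 10  # always lands in [0, bucket_size)
```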
@@ -625,7 +709,7 @@ def categorical_column_with_vocabulary_file(
   `num_oov_buckets` and `default_value` to specify how to include
   out-of-vocabulary values.

-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that these
   values are independent of the `default_value` argument.

@@ -693,7 +777,6 @@ def categorical_column_with_vocabulary_file(
   if not vocabulary_file:
     raise ValueError('Missing vocabulary_file in {}.'.format(key))
   # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
   # TODO(ptucker): Should we fail for vocabulary_size==1?
   if (vocabulary_size is None) or (vocabulary_size < 1):
     raise ValueError('Invalid vocabulary_size in {}.'.format(key))
   if num_oov_buckets:
@@ -719,14 +802,14 @@ def categorical_column_with_vocabulary_list(
   """A `_CategoricalColumn` with in-memory vocabulary.

   Logic for feature f is:
-  id = f in vocabulary_list ? vocabulary_list.index(f) : default_value
+  id = f in vocabulary_list ? vocabulary_list.index_of(f) : default_value

   Use this when your inputs are in string or integer format, and you have an
   in-memory vocabulary mapping each value to an integer ID. By default,
   out-of-vocabulary values are ignored. Use `default_value` to specify how to
   include out-of-vocabulary values.

-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that these
   values are independent of the `default_value` argument.
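The ternary in the docstring is pseudocode; in plain Python the same rule reads as below (a sketch, not the lookup op itself):

```python
def vocabulary_id(f, vocabulary_list, default_value=-1):
  # Position in the list if present, else default_value.
  return vocabulary_list.index(f) if f in vocabulary_list else default_value

assert vocabulary_id('stringer', ('omar', 'stringer', 'marlo')) == 1
assert vocabulary_id('avon', ('omar', 'stringer', 'marlo')) == -1
```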
@@ -795,11 +878,16 @@ def categorical_column_with_vocabulary_list(
 def categorical_column_with_identity(key, num_buckets, default_value=None):
   """A `_CategoricalColumn` that returns identity values.

-  Use this when your inputs are integers in the range `[0, num_buckets)`. Values
-  outside this range will result in `default_value` if specified, otherwise it
-  will fail.
+  Use this when your inputs are integers in the range `[0, num_buckets)`, and
+  you want to use the input value itself as the categorical ID. Values outside
+  this range will result in `default_value` if specified, otherwise it will
+  fail.

-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  Typically, this is used for contiguous ranges of integer indexes, but
+  it doesn't have to be. This might be inefficient, however, if many of the
+  IDs are unused. Consider `categorical_column_with_hash_bucket` in that case.
+
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that these
   values are independent of the `default_value` argument.
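A plain-Python sketch of the identity rule just described (illustrative only; the real column applies this element-wise to sparse input):

```python
def identity_id(value, num_buckets, default_value=None):
  if 0 <= value < num_buckets:
    return value          # the input value is itself the categorical ID
  if default_value is not None:
    return default_value  # out-of-range values fall back to default_value
  raise ValueError('Value {} out of range [0, {})'.format(value, num_buckets))
```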
@@ -960,15 +1048,15 @@ class _DenseColumn(_FeatureColumn):

   @abc.abstractproperty
   def _variable_shape(self):
-    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
+    """`TensorShape` of `_get_dense_tensor`, without batch dimension."""
     pass

   @abc.abstractmethod
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns a `Tensor`.

-    The output of this function will be used by model-buildier-functions. For
-    example the pseudo code of `make_input_layer` will be like that:
+    The output of this function will be used by model-builder functions. For
+    example, the pseudo code of `make_input_layer` will be like:

     ```python
     def make_input_layer(features, feature_columns, ...):
@@ -982,6 +1070,9 @@ class _DenseColumn(_FeatureColumn):
         will be created) are added.
       trainable: If `True` also add variables to the graph collection
         `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.Variable}).
+
+    Returns:
+      `Tensor` of shape [batch_size] + `_variable_shape`.
     """
     pass

@@ -997,7 +1088,7 @@ def _create_dense_column_weighted_sum(
   batch_size = array_ops.shape(tensor)[0]
   tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
   weight = variable_scope.get_variable(
-      name='weight',
+      name='weights',
       shape=[num_elements, units],
       initializer=init_ops.zeros_initializer(),
       trainable=trainable,
@@ -1060,7 +1151,7 @@ def _create_categorical_column_weighted_sum(
       trainable=trainable)
   weight = variable_scope.get_variable(
       name='weight',
-      shape=[column._num_buckets, units],  # pylint: disable=protected-access
+      shape=(column._num_buckets, units),  # pylint: disable=protected-access
      initializer=init_ops.zeros_initializer(),
      trainable=trainable,
      collections=weight_collections)
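Both weighted-sum helpers above implement the same algebra: flatten the column's output to (batch_size, num_elements) and multiply by a zero-initialized (num_elements, units) weights matrix. A numpy sketch of the dense case (names here are mine, not the module's):

```python
import numpy as np

def dense_weighted_sum(tensor, units):
  batch_size = tensor.shape[0]
  flat = tensor.reshape(batch_size, -1)       # (batch_size, num_elements)
  weights = np.zeros((flat.shape[1], units))  # mirrors zeros_initializer()
  return flat @ weights                       # (batch_size, units)
```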
@@ -1365,6 +1456,74 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
     return _CategoricalColumn.IdWeightPair(sparse_tensor, None)


+class _EmbeddingColumn(
+    _DenseColumn,
+    collections.namedtuple('_EmbeddingColumn', (
+        'categorical_column', 'dimension', 'combiner', 'initializer',
+        'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
+    ))):
+  """See `embedding_column`."""
+
+  @property
+  def name(self):
+    if not hasattr(self, '_name'):
+      self._name = '{}_embedding'.format(self.categorical_column.name)
+    return self._name
+
+  @property
+  def _parse_example_config(self):
+    # pylint: disable=protected-access
+    return self.categorical_column._parse_example_config
+    # pylint: enable=protected-access
+
+  def _transform_feature(self, inputs):
+    return inputs.get(self.categorical_column)
+
+  @property
+  def _variable_shape(self):
+    if not hasattr(self, '_shape'):
+      self._shape = tensor_shape.vector(self.dimension)
+    return self._shape
+
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+    # Get sparse IDs and weights.
+    # pylint: disable=protected-access
+    sparse_tensors = self.categorical_column._get_sparse_tensors(
+        inputs, weight_collections=weight_collections, trainable=trainable)
+    # pylint: enable=protected-access
+    sparse_ids = sparse_tensors.id_tensor
+    sparse_weights = sparse_tensors.weight_tensor
+
+    # Create embedding weight, and restore from checkpoint if necessary.
+    embedding_weights = variable_scope.get_variable(
+        name='embedding_weights',
+        # pylint: disable=protected-access
+        shape=(self.categorical_column._num_buckets, self.dimension),
+        # pylint: enable=protected-access
+        dtype=dtypes.float32,
+        initializer=self.initializer,
+        trainable=self.trainable and trainable,
+        collections=weight_collections)
+    if self.ckpt_to_load_from is not None:
+      to_restore = embedding_weights
+      if isinstance(to_restore, variables.PartitionedVariable):
+        # pylint: disable=protected-access
+        to_restore = to_restore._get_variable_list()
+        # pylint: enable=protected-access
+      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+          self.tensor_name_in_ckpt: to_restore
+      })
+
+    # Return embedding lookup result.
+    return _safe_embedding_lookup_sparse(
+        embedding_weights=embedding_weights,
+        sparse_ids=sparse_ids,
+        sparse_weights=sparse_weights,
+        combiner=self.combiner,
+        name='%s_weights' % self.name,
+        max_norm=self.max_norm)


 def _create_tuple(shape, value):
   """Returns a tuple with given shape and filled with value."""
   if shape:
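The tests below exercise `_get_dense_tensor` with the default combiner='mean'; the expected rows can be reproduced in numpy. A sketch using the same ids and embedding table as `EmbeddingColumnTest.test_get_dense_tensor`:

```python
import numpy as np

embedding = np.array([[1., 2.], [3., 5.], [7., 11.]])  # rows for ids 0, 1, 2
rows = [[2], [0, 1], [], [1]]  # per-example id lists from the test input
lookups = [embedding[r].mean(axis=0) if r else np.zeros(2) for r in rows]
# [[7, 11], [2, 3.5], [0, 0], [3, 5]], matching expected_lookups in the test.
```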
@@ -1689,7 +1848,7 @@ class _IdentityCategoricalColumn(
 def _safe_embedding_lookup_sparse(embedding_weights,
                                   sparse_ids,
                                   sparse_weights=None,
-                                  combiner=None,
+                                  combiner='mean',
                                   default_id=None,
                                   name=None,
                                   partition_strategy='div',
@@ -1727,7 +1886,7 @@ def _safe_embedding_lookup_sparse(embedding_weights,
     name: A name for this operation (optional).
     partition_strategy: A string specifying the partitioning strategy.
       Currently `"div"` and `"mod"` are supported. Default is `"div"`.
-    max_norm: If not None, all embeddings are l2-normalized to max_norm before
+    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

@@ -1737,10 +1896,6 @@ def _safe_embedding_lookup_sparse(embedding_weights,
   Raises:
     ValueError: if `embedding_weights` is empty.
   """
-  if combiner is None:
-    logging.warn('The default value of combiner will change from \"mean\" '
-                 'to \"sqrtn\" after 2016/11/01.')
-    combiner = 'mean'
   if embedding_weights is None:
     raise ValueError('Missing embedding_weights %s.' % embedding_weights)
   if isinstance(embedding_weights, variables.PartitionedVariable):
@@ -1751,8 +1906,6 @@ def _safe_embedding_lookup_sparse(embedding_weights,
     raise ValueError('Missing embedding_weights %s.' % embedding_weights)

   dtype = sparse_weights.dtype if sparse_weights is not None else None
-  if isinstance(embedding_weights, variables.PartitionedVariable):
-    embedding_weights = list(embedding_weights)
   embedding_weights = [
       ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
   ]
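With the 2016 deprecation warning gone, `combiner` now defaults to 'mean' outright. For one multi-id row with unit weights, the three supported combiners differ as follows (numpy sketch):

```python
import numpy as np

vecs = np.array([[1., 2.], [3., 5.]])   # embeddings of one row's two ids
print(vecs.sum(axis=0))                 # 'sum'   -> [4. 7.]
print(vecs.mean(axis=0))                # 'mean'  -> [2. 3.5]
print(vecs.sum(axis=0) / np.sqrt(2))    # 'sqrtn' -> sum / sqrt(number of ids)
```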
@@ -577,7 +577,6 @@ class HashedCategoricalColumnTest(test.TestCase):
       fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)

   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_hash_bucket('aaa', 10)
     for column in (original, copy.deepcopy(original)):
       self.assertEqual('aaa', column.name)
@@ -692,6 +691,19 @@ class HashedCategoricalColumnTest(test.TestCase):
     self.assertIsNone(id_weight_pair.weight_tensor)
     self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)

+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_hash_bucket('aaa', 10)
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
     builder = fc._LazyBuilder({
@@ -725,7 +737,7 @@ class HashedCategoricalColumnTest(test.TestCase):

 def get_linear_model_bias():
   with variable_scope.variable_scope('make_linear_model', reuse=True):
-    return variable_scope.get_variable('bias_weight')
+    return variable_scope.get_variable('bias_weights')


 def get_linear_model_column_var(column):
@@ -885,8 +897,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       price_var = get_linear_model_column_var(price)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0., 0., 0.]], price_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((1, 3)), price_var.eval())
         sess.run(price_var.assign([[10., 100., 1000.]]))
         sess.run(bias.assign([5., 6., 7.]))
         self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
@@ -904,8 +916,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       wire_cast_var = get_linear_model_column_var(wire_cast)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0.] * 3] * 4, wire_cast_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
         sess.run(
             wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
                 1000., 1100., 1200.
@@ -950,8 +962,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       price_var = get_linear_model_column_var(price)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0., 0., 0.], [0., 0., 0.]], price_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((2, 3)), price_var.eval())
         sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
         sess.run(bias.assign([2., 3., 4.]))
         self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
@@ -1327,7 +1339,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)

   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_vocabulary_file(
         key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
         num_oov_buckets=4, dtype=dtypes.int32)
@@ -1476,6 +1487,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_tensor.eval())

+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_vocabulary_file(
         key='aaa',
@@ -1684,7 +1711,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)

   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_vocabulary_list(
         key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
     for column in (original, copy.deepcopy(original)):
@@ -1778,6 +1804,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_weight_pair.id_tensor.eval())

+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_vocabulary_list(
+        key='aaa',
+        vocabulary_list=('omar', 'stringer', 'marlo'))
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_vocabulary_list(
         key='aaa',
@@ -1894,7 +1935,6 @@ class IdentityCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)

   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     for column in (original, copy.deepcopy(original)):
       self.assertEqual('aaa', column.name)
@@ -1948,6 +1988,19 @@ class IdentityCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_weight_pair.id_tensor.eval())

+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(0, 1, 0),
+        dense_shape=(2, 2))
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
@@ -2221,5 +2274,553 @@ class IndicatorColumnTest(test.TestCase):
       self.assertAllClose([[0., 1., 1., 0.]], net.eval())


+class EmbeddingColumnTest(test.TestCase):
+
+  def test_defaults(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    self.assertEqual('mean', embedding_column.combiner)
+    self.assertIsNotNone(embedding_column.initializer)
+    self.assertIsNone(embedding_column.ckpt_to_load_from)
+    self.assertIsNone(embedding_column.tensor_name_in_ckpt)
+    self.assertIsNone(embedding_column.max_norm)
+    self.assertTrue(embedding_column.trainable)
+    self.assertEqual('aaa_embedding', embedding_column.name)
+    self.assertEqual(
+        (embedding_dimension,), embedding_column._variable_shape)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+    }, embedding_column._parse_example_config)
+
+  def test_all_constructor_args(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        combiner='my_combiner', initializer=lambda: 'my_initializer',
+        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+        max_norm=42., trainable=False)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    self.assertEqual('my_combiner', embedding_column.combiner)
+    self.assertEqual('my_initializer', embedding_column.initializer())
+    self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+    self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+    self.assertEqual(42., embedding_column.max_norm)
+    self.assertFalse(embedding_column.trainable)
+    self.assertEqual('aaa_embedding', embedding_column.name)
+    self.assertEqual(
+        (embedding_dimension,), embedding_column._variable_shape)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+    }, embedding_column._parse_example_config)
+
+  def test_deep_copy(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    original = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        combiner='my_combiner', initializer=lambda: 'my_initializer',
+        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+        max_norm=42., trainable=False)
+    for embedding_column in (original, copy.deepcopy(original)):
+      self.assertEqual('aaa', embedding_column.categorical_column.name)
+      self.assertEqual(3, embedding_column.categorical_column._num_buckets)
+      self.assertEqual({
+          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+      }, embedding_column.categorical_column._parse_example_config)
+
+      self.assertEqual(embedding_dimension, embedding_column.dimension)
+      self.assertEqual('my_combiner', embedding_column.combiner)
+      self.assertEqual('my_initializer', embedding_column.initializer())
+      self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+      self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+      self.assertEqual(42., embedding_column.max_norm)
+      self.assertFalse(embedding_column.trainable)
+      self.assertEqual('aaa_embedding', embedding_column.name)
+      self.assertEqual(
+          (embedding_dimension,), embedding_column._variable_shape)
+      self.assertEqual({
+          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+      }, embedding_column._parse_example_config)
+
+  def test_invalid_initializer(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+      fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
+
+  def test_get_dense_tensor(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+  def test_get_dense_tensor_3d(self):
+    # Inputs.
+    vocabulary_size = 4
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [[2], []]
+        # example 1, ids [[], [0, 1]]
+        # example 2, ids [[], []]
+        # example 3, ids [[1], [2]]
+        indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)),
+        values=(2, 0, 1, 1, 2),
+        dense_shape=(4, 2, 5))
+
+    # Embedding variable.
+    embedding_dimension = 3
+    embedding_values = (
+        (1., 2., 4.),   # id 0
+        (3., 5., 1.),   # id 1
+        (7., 11., 2.),  # id 2
+        (2., 7., 12.)   # id 3
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]]
+        ((7., 11., 2.), (0., 0., 0.)),
+        # example 1, ids [[], [0, 1]], embedding
+        # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]]
+        ((0., 0., 0.), (2., 3.5, 2.5)),
+        # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]]
+        ((0., 0., 0.), (0., 0., 0.)),
+        # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]]
+        ((3., 5., 1.), (7., 11., 2.)),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+  def test_get_dense_tensor_weight_collections(self):
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_column = fc.embedding_column(categorical_column, dimension=2)
+
+    # Provide sparse input and get dense result.
+    embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }), weight_collections=('my_vars',))
+
+    # Assert expected embedding variable and lookups.
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    my_vars = ops.get_collection('my_vars')
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in my_vars]))
+
+  def test_get_dense_tensor_placeholder_inputs(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    input_indices = array_ops.placeholder(dtype=dtypes.int64)
+    input_values = array_ops.placeholder(dtype=dtypes.int64)
+    input_shape = array_ops.placeholder(dtype=dtypes.int64)
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_tensor.SparseTensorValue(
+            indices=input_indices,
+            values=input_values,
+            dense_shape=input_shape)
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval(
+          feed_dict={
+              input_indices: sparse_input.indices,
+              input_values: sparse_input.values,
+              input_shape: sparse_input.dense_shape,
+          }))
+
+  def test_get_dense_tensor_restore_from_ckpt(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable. The checkpoint file contains embedding_values.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    ckpt_path = test.test_src_dir_path(
+        'python/feature_column/testdata/embedding.ckpt')
+    ckpt_tensor = 'my_embedding'
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        ckpt_to_load_from=ckpt_path,
+        tensor_name_in_ckpt=ckpt_tensor)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+  def test_make_linear_model(self):
+    # Inputs.
+    batch_size = 4
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(batch_size, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_shape = (vocabulary_size, embedding_dimension)
+    zeros_embedding_values = np.zeros(embedding_shape)
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual(embedding_shape, shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return zeros_embedding_values
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    with ops.Graph().as_default():
+      predictions = fc.make_linear_model({
+          categorical_column.name: sparse_input
+      }, (embedding_column,))
+      expected_var_names = (
+          'make_linear_model/bias_weights:0',
+          'make_linear_model/aaa_embedding/weights:0',
+          'make_linear_model/aaa_embedding/embedding_weights:0',
+      )
+      self.assertItemsEqual(
+          expected_var_names,
+          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+      trainable_vars = {
+          v.name: v for v in ops.get_collection(
+              ops.GraphKeys.TRAINABLE_VARIABLES)
+      }
+      self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+      bias = trainable_vars['make_linear_model/bias_weights:0']
+      embedding_weights = trainable_vars[
+          'make_linear_model/aaa_embedding/embedding_weights:0']
+      linear_weights = trainable_vars[
+          'make_linear_model/aaa_embedding/weights:0']
+      with _initialized_session():
+        # Predictions with all zero weights.
+        self.assertAllClose(np.zeros((1,)), bias.eval())
+        self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+        self.assertAllClose(
+            np.zeros((embedding_dimension, 1)), linear_weights.eval())
+        self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+        # Predictions with all non-zero weights.
+        embedding_weights.assign((
+            (1., 2.),  # id 0
+            (3., 5.),  # id 1
+            (7., 11.)  # id 2
+        )).eval()
+        linear_weights.assign(((4.,), (6.,))).eval()
+        # example 0, ids [2], embedding[0] = [7, 11]
+        # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+        # example 2, ids [], embedding[2] = [0, 0]
+        # example 3, ids [1], embedding[3] = [3, 5]
+        # sum(embeddings * linear_weights)
+        # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+        self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+  def test_make_input_layer(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    input_layer = fc.make_input_layer({
+        'aaa': sparse_input
+    }, (embedding_column,))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('make_input_layer/aaa_embedding/embedding_weights:0',),
+        tuple([v.name for v in global_vars]))
+    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+    self.assertItemsEqual(
+        ('make_input_layer/aaa_embedding/embedding_weights:0',),
+        tuple([v.name for v in trainable_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, trainable_vars[0].eval())
+      self.assertAllEqual(expected_lookups, input_layer.eval())
+
+  def test_make_input_layer_not_trainable(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer, trainable=False)
+
+    # Provide sparse input and get dense result.
+    input_layer = fc.make_input_layer({
+        'aaa': sparse_input
+    }, (embedding_column,))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('make_input_layer/aaa_embedding/embedding_weights:0',),
+        tuple([v.name for v in global_vars]))
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, input_layer.eval())
+
+
 if __name__ == '__main__':
   test.main()
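The `test_make_linear_model` expectations above can be checked by hand; a numpy sketch of the same arithmetic (embedding table and linear weights copied from the test):

```python
import numpy as np

embedding = np.array([[1., 2.], [3., 5.], [7., 11.]])  # ids 0, 1, 2
linear_w = np.array([4., 6.])
rows = [[2], [0, 1], [], [1]]
preds = [np.dot(linear_w, embedding[r].mean(axis=0) if r else np.zeros(2))
         for r in rows]
print(preds)  # [94.0, 29.0, 0.0, 42.0], the test's expected predictions
```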
BIN  tensorflow/python/feature_column/testdata/embedding.ckpt.data-00000-of-00001 (new binary file, not shown)
BIN  tensorflow/python/feature_column/testdata/embedding.ckpt.index (new binary file, not shown)
BIN  tensorflow/python/feature_column/testdata/embedding.ckpt.meta (new binary file, not shown)