diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index ac7aef96ac1..8e26465f262 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -39,6 +39,7 @@ py_library(
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:string_ops",
+        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
@@ -48,6 +49,9 @@ py_library(
 filegroup(
     name = "vocabulary_testdata",
     srcs = [
+        "testdata/embedding.ckpt.data-00000-of-00001",
+        "testdata/embedding.ckpt.index",
+        "testdata/embedding.ckpt.meta",
        "testdata/warriors_vocabulary.txt",
        "testdata/wire_vocabulary.txt",
    ],
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 50d7e00fd92..557a543ef24 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -123,6 +123,7 @@ from __future__ import print_function
 
 import abc
 import collections
+import math
 
 import numpy as np
 import six
@@ -144,6 +145,7 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.util import nest
 
 
@@ -304,7 +306,7 @@ def make_linear_model(features,
   predictions_no_bias = math_ops.add_n(
       weigthed_sums, name='weighted_sum_no_bias')
   bias = variable_scope.get_variable(
-      'bias_weight',
+      'bias_weights',
       shape=[units],
       initializer=init_ops.zeros_initializer(),
       trainable=trainable,
@@ -416,6 +418,88 @@ def make_parse_example_spec(feature_columns):
   return result
 
 
+def embedding_column(
+    categorical_column, dimension, combiner='mean', initializer=None,
+    ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
+    trainable=True):
+  """`_DenseColumn` that converts from sparse, categorical input.
+
+  Use this when your inputs are sparse, but you want to convert them to a dense
+  representation (e.g., to feed to a DNN).
+
+  Inputs must be `SparseTensor`, obtained via the provided
+  `categorical_column._get_sparse_tensors`.
+
+  Any of the `categorical_column_*` columns can be provided as input. Here is
+  an example embedding of an identity column for a DNN model:
+
+  ```python
+  video_id = categorical_column_with_identity(
+      key='video_id', num_buckets=1000000, default_value=0)
+  columns = [embedding_column(video_id, 9),...]
+  features = tf.parse_example(..., features=make_parse_example_spec(columns))
+  dense_tensor = make_input_layer(features, columns)
+  ```
+
+  Args:
+    categorical_column: A `_CategoricalColumn` created by a
+      `categorical_column_with_*` function. This column produces the sparse IDs
+      that are inputs to the embedding lookup.
+    dimension: An integer specifying the dimension of the embedding; must be
+      > 0.
+    combiner: A string specifying how to reduce if there are multiple entries
+      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+      with bag-of-words columns. Each of these can be thought of as an example
+      level normalization on the column. For more information, see
+      `tf.embedding_lookup_sparse`.
+    initializer: A variable initializer function to be used in embedding
+      variable initialization. If not specified, defaults to
+      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+      `1/sqrt(dimension)`.
+    ckpt_to_load_from: String representing checkpoint name/pattern from which
+      to restore column weights. Required if `tensor_name_in_ckpt` is not
+      `None`.
+    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+      which to restore the column weights. Required if `ckpt_to_load_from` is
+      not `None`.
+    max_norm: If not `None`, embedding values are l2-normalized to this value.
+    trainable: Whether or not the embedding is trainable. Default is True.
+
+  Returns:
+    `_DenseColumn` that converts from sparse input.
+
+  Raises:
+    ValueError: if `dimension` not > 0.
+    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+      is specified.
+    ValueError: if `initializer` is specified and is not callable.
+  """
+  if (dimension is None) or (dimension < 1):
+    raise ValueError('Invalid dimension {}.'.format(dimension))
+  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+    raise ValueError('Must specify both `ckpt_to_load_from` and '
+                     '`tensor_name_in_ckpt` or none of them.')
+
+  if (initializer is not None) and (not callable(initializer)):
+    raise ValueError('initializer must be callable if specified. '
+                     'Embedding of column_name: {}'.format(
+                         categorical_column.name))
+  if initializer is None:
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1 / math.sqrt(dimension))
+
+  return _EmbeddingColumn(
+      categorical_column=categorical_column,
+      dimension=dimension,
+      combiner=combiner,
+      initializer=initializer,
+      ckpt_to_load_from=ckpt_to_load_from,
+      tensor_name_in_ckpt=tensor_name_in_ckpt,
+      max_norm=max_norm,
+      trainable=trainable)
+
+
 def numeric_column(key,
                    shape=(1,),
                    default_value=None,
@@ -568,7 +652,7 @@ def categorical_column_with_hash_bucket(key,
   want to distribute your inputs into a finite number of buckets by hashing.
   output_id = Hash(input_feature_string) % bucket_size
 
-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that
   these values are independent of the `default_value` argument.
@@ -625,7 +709,7 @@ def categorical_column_with_vocabulary_file(
   `num_oov_buckets` and `default_value` to specify how to include
   out-of-vocabulary values.
 
-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that
   these values are independent of the `default_value` argument.
@@ -693,7 +777,6 @@ def categorical_column_with_vocabulary_file(
   if not vocabulary_file:
     raise ValueError('Missing vocabulary_file in {}.'.format(key))
   # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
-  # TODO(ptucker): Should we fail for vocabulary_size==1?
   if (vocabulary_size is None) or (vocabulary_size < 1):
     raise ValueError('Invalid vocabulary_size in {}.'.format(key))
   if num_oov_buckets:
@@ -719,14 +802,14 @@ def categorical_column_with_vocabulary_list(
   """A `_CategoricalColumn` with in-memory vocabulary.
 
   Logic for feature f is:
-  id = f in vocabulary_list ? vocabulary_list.index(f) : default_value
+  id = f in vocabulary_list ? vocabulary_list.index_of(f) : default_value
 
   Use this when your inputs are in string or integer format, and you have an
   in-memory vocabulary mapping each value to an integer ID. By default,
   out-of-vocabulary values are ignored. Use `default_value` to specify how to
   include out-of-vocabulary values.
 
-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that
   these values are independent of the `default_value` argument.
@@ -795,11 +878,16 @@ def categorical_column_with_vocabulary_list(
 
 def categorical_column_with_identity(key, num_buckets, default_value=None):
   """A `_CategoricalColumn` that returns identity values.
 
-  Use this when your inputs are integers in the range `[0, num_buckets)`. Values
-  outside this range will result in `default_value` if specified, otherwise it
-  will fail.
+  Use this when your inputs are integers in the range `[0, num_buckets)`, and
+  you want to use the input value itself as the categorical ID. Values outside
+  this range will result in `default_value` if specified, otherwise it will
+  fail.
 
-  `features[key]` are either `Tensor` or `SparseTensor`. If `Tensor`, missing
+  Typically, this is used for contiguous ranges of integer indexes, but it
+  doesn't have to be. This might be inefficient, however, if many of the IDs
+  are unused. Consider `categorical_column_with_hash_bucket` in that case.
+
+  `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing
   values can be represented by `-1` for int and `''` for string. Note that
   these values are independent of the `default_value` argument.
@@ -960,15 +1048,15 @@ class _DenseColumn(_FeatureColumn):
 
   @abc.abstractproperty
   def _variable_shape(self):
-    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
+    """`TensorShape` of `_get_dense_tensor`, without batch dimension."""
     pass
 
   @abc.abstractmethod
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns a `Tensor`.
 
-    The output of this function will be used by model-buildier-functions. For
-    example the pseudo code of `make_input_layer` will be like that:
+    The output of this function will be used by model-builder-functions. For
+    example the pseudo code of `make_input_layer` will be like:
 
     ```python
     def make_input_layer(features, feature_columns, ...):
@@ -982,6 +1070,9 @@ class _DenseColumn(_FeatureColumn):
         will be created) are added.
       trainable: If `True` also add variables to the graph collection
         `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.Variable}).
+
+    Returns:
+      `Tensor` of shape [batch_size] + `_variable_shape`.
     """
     pass
 
@@ -997,7 +1088,7 @@ def _create_dense_column_weighted_sum(
   batch_size = array_ops.shape(tensor)[0]
   tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
   weight = variable_scope.get_variable(
-      name='weight',
+      name='weights',
       shape=[num_elements, units],
       initializer=init_ops.zeros_initializer(),
       trainable=trainable,
@@ -1060,7 +1151,7 @@ def _create_categorical_column_weighted_sum(
       trainable=trainable)
   weight = variable_scope.get_variable(
       name='weight',
-      shape=[column._num_buckets, units],  # pylint: disable=protected-access
+      shape=(column._num_buckets, units),  # pylint: disable=protected-access
       initializer=init_ops.zeros_initializer(),
       trainable=trainable,
       collections=weight_collections)
@@ -1365,6 +1456,74 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
     return _CategoricalColumn.IdWeightPair(sparse_tensor, None)
 
 
+class _EmbeddingColumn(
+    _DenseColumn,
+    collections.namedtuple('_EmbeddingColumn', (
+        'categorical_column', 'dimension', 'combiner', 'initializer',
+        'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
+    ))):
+  """See `embedding_column`."""
+
+  @property
+  def name(self):
+    if not hasattr(self, '_name'):
+      self._name = '{}_embedding'.format(self.categorical_column.name)
+    return self._name
+
+  @property
+  def _parse_example_config(self):
+    # pylint: disable=protected-access
+    return self.categorical_column._parse_example_config
+    # pylint: enable=protected-access
+
+  def _transform_feature(self, inputs):
+    return inputs.get(self.categorical_column)
+
+  @property
+  def _variable_shape(self):
+    if not hasattr(self, '_shape'):
+      self._shape = tensor_shape.vector(self.dimension)
+    return self._shape
+
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+    # Get sparse IDs and weights.
+    # pylint: disable=protected-access
+    sparse_tensors = self.categorical_column._get_sparse_tensors(
+        inputs, weight_collections=weight_collections, trainable=trainable)
+    # pylint: enable=protected-access
+    sparse_ids = sparse_tensors.id_tensor
+    sparse_weights = sparse_tensors.weight_tensor
+
+    # Create embedding weight, and restore from checkpoint if necessary.
+    embedding_weights = variable_scope.get_variable(
+        name='embedding_weights',
+        # pylint: disable=protected-access
+        shape=(self.categorical_column._num_buckets, self.dimension),
+        # pylint: enable=protected-access
+        dtype=dtypes.float32,
+        initializer=self.initializer,
+        trainable=self.trainable and trainable,
+        collections=weight_collections)
+    if self.ckpt_to_load_from is not None:
+      to_restore = embedding_weights
+      if isinstance(to_restore, variables.PartitionedVariable):
+        # pylint: disable=protected-access
+        to_restore = to_restore._get_variable_list()
+        # pylint: enable=protected-access
+      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+          self.tensor_name_in_ckpt: to_restore
+      })
+
+    # Return embedding lookup result.
+    return _safe_embedding_lookup_sparse(
+        embedding_weights=embedding_weights,
+        sparse_ids=sparse_ids,
+        sparse_weights=sparse_weights,
+        combiner=self.combiner,
+        name='%s_weights' % self.name,
+        max_norm=self.max_norm)
+
+
 def _create_tuple(shape, value):
   """Returns a tuple with given shape and filled with value."""
   if shape:
@@ -1689,7 +1848,7 @@ class _IdentityCategoricalColumn(
 
 def _safe_embedding_lookup_sparse(embedding_weights,
                                   sparse_ids,
                                   sparse_weights=None,
-                                  combiner=None,
+                                  combiner='mean',
                                   default_id=None,
                                   name=None,
                                   partition_strategy='div',
@@ -1727,7 +1886,7 @@ def _safe_embedding_lookup_sparse(embedding_weights,
     name: A name for this operation (optional).
     partition_strategy: A string specifying the partitioning strategy.
       Currently `"div"` and `"mod"` are supported. Default is `"div"`.
-    max_norm: If not None, all embeddings are l2-normalized to max_norm before
+    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
       combining.
 
@@ -1737,10 +1896,6 @@ def _safe_embedding_lookup_sparse(embedding_weights,
 
   Raises:
     ValueError: if `embedding_weights` is empty.
   """
-  if combiner is None:
-    logging.warn('The default value of combiner will change from \"mean\" '
-                 'to \"sqrtn\" after 2016/11/01.')
-    combiner = 'mean'
   if embedding_weights is None:
     raise ValueError('Missing embedding_weights %s.' % embedding_weights)
   if isinstance(embedding_weights, variables.PartitionedVariable):
@@ -1751,8 +1906,6 @@ def _safe_embedding_lookup_sparse(embedding_weights,
     raise ValueError('Missing embedding_weights %s.' % embedding_weights)
 
   dtype = sparse_weights.dtype if sparse_weights is not None else None
-  if isinstance(embedding_weights, variables.PartitionedVariable):
-    embedding_weights = list(embedding_weights)
   embedding_weights = [
       ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
   ]
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index cf52d2df4fd..56fe5180d8e 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -577,7 +577,6 @@ class HashedCategoricalColumnTest(test.TestCase):
       fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
 
   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_hash_bucket('aaa', 10)
     for column in (original, copy.deepcopy(original)):
       self.assertEqual('aaa', column.name)
@@ -692,6 +691,19 @@ class HashedCategoricalColumnTest(test.TestCase):
     self.assertIsNone(id_weight_pair.weight_tensor)
     self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor)
 
+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_hash_bucket('aaa', 10)
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
     builder = fc._LazyBuilder({
@@ -725,7 +737,7 @@ class HashedCategoricalColumnTest(test.TestCase):
 
 def get_linear_model_bias():
   with variable_scope.variable_scope('make_linear_model', reuse=True):
-    return variable_scope.get_variable('bias_weight')
+    return variable_scope.get_variable('bias_weights')
 
 
 def get_linear_model_column_var(column):
@@ -885,8 +897,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       price_var = get_linear_model_column_var(price)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0., 0., 0.]], price_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((1, 3)), price_var.eval())
         sess.run(price_var.assign([[10., 100., 1000.]]))
         sess.run(bias.assign([5., 6., 7.]))
         self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
@@ -904,8 +916,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       wire_cast_var = get_linear_model_column_var(wire_cast)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0.] * 3] * 4, wire_cast_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
         sess.run(
             wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
                 1000., 1100., 1200.
@@ -950,8 +962,8 @@ class MakeLinearModelTest(test.TestCase):
       bias = get_linear_model_bias()
       price_var = get_linear_model_column_var(price)
       with _initialized_session() as sess:
-        self.assertAllClose([0., 0., 0.], bias.eval())
-        self.assertAllClose([[0., 0., 0.], [0., 0., 0.]], price_var.eval())
+        self.assertAllClose(np.zeros((3,)), bias.eval())
+        self.assertAllClose(np.zeros((2, 3)), price_var.eval())
         sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
         sess.run(bias.assign([2., 3., 4.]))
         self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
@@ -1327,7 +1339,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)
 
   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_vocabulary_file(
         key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
         num_oov_buckets=4, dtype=dtypes.int32)
@@ -1476,6 +1487,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_tensor.eval())
 
+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_vocabulary_file(
+        key='aaa',
+        vocabulary_file=self._wire_vocabulary_file_name,
+        vocabulary_size=self._wire_vocabulary_size)
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_vocabulary_file(
         key='aaa',
@@ -1684,7 +1711,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)
 
   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_vocabulary_list(
         key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
     for column in (original, copy.deepcopy(original)):
@@ -1778,6 +1804,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_weight_pair.id_tensor.eval())
 
+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_vocabulary_list(
+        key='aaa',
+        vocabulary_list=('omar', 'stringer', 'marlo'))
+    inputs = sparse_tensor.SparseTensor(
+        values=['omar', 'stringer', 'marlo'],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_vocabulary_list(
         key='aaa',
@@ -1894,7 +1935,6 @@ class IdentityCategoricalColumnTest(test.TestCase):
     }, column._parse_example_config)
 
   def test_deep_copy(self):
-    """Tests deepcopy of categorical_column_with_hash_bucket."""
     original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     for column in (original, copy.deepcopy(original)):
       self.assertEqual('aaa', column.name)
@@ -1948,6 +1988,19 @@ class IdentityCategoricalColumnTest(test.TestCase):
             dense_shape=inputs.dense_shape),
         id_weight_pair.id_tensor.eval())
 
+  def test_get_sparse_tensors_weight_collections(self):
+    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+    inputs = sparse_tensor.SparseTensorValue(
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(0, 1, 0),
+        dense_shape=(2, 2))
+    column._get_sparse_tensors(
+        fc._LazyBuilder({'aaa': inputs}), weight_collections=('my_weights',))
+
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    self.assertItemsEqual([], ops.get_collection('my_weights'))
+
   def test_get_sparse_tensors_dense_input(self):
     column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
@@ -2221,5 +2274,553 @@ class IndicatorColumnTest(test.TestCase):
     self.assertAllClose([[0., 1., 1., 0.]], net.eval())
 
 
+class EmbeddingColumnTest(test.TestCase):
+
+  def test_defaults(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    self.assertEqual('mean', embedding_column.combiner)
+    self.assertIsNotNone(embedding_column.initializer)
+    self.assertIsNone(embedding_column.ckpt_to_load_from)
+    self.assertIsNone(embedding_column.tensor_name_in_ckpt)
+    self.assertIsNone(embedding_column.max_norm)
+    self.assertTrue(embedding_column.trainable)
+    self.assertEqual('aaa_embedding', embedding_column.name)
+    self.assertEqual(
+        (embedding_dimension,), embedding_column._variable_shape)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+    }, embedding_column._parse_example_config)
+
+  def test_all_constructor_args(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        combiner='my_combiner', initializer=lambda: 'my_initializer',
+        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+        max_norm=42., trainable=False)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    self.assertEqual('my_combiner', embedding_column.combiner)
+    self.assertEqual('my_initializer', embedding_column.initializer())
+    self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+    self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+    self.assertEqual(42., embedding_column.max_norm)
+    self.assertFalse(embedding_column.trainable)
+    self.assertEqual('aaa_embedding', embedding_column.name)
+    self.assertEqual(
+        (embedding_dimension,), embedding_column._variable_shape)
+    self.assertEqual({
+        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+    }, embedding_column._parse_example_config)
+
+  def test_deep_copy(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    original = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        combiner='my_combiner', initializer=lambda: 'my_initializer',
+        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+        max_norm=42., trainable=False)
+    for embedding_column in (original, copy.deepcopy(original)):
+      self.assertEqual('aaa', embedding_column.categorical_column.name)
+      self.assertEqual(3, embedding_column.categorical_column._num_buckets)
+      self.assertEqual({
+          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+      }, embedding_column.categorical_column._parse_example_config)
+
+      self.assertEqual(embedding_dimension, embedding_column.dimension)
+      self.assertEqual('my_combiner', embedding_column.combiner)
+      self.assertEqual('my_initializer', embedding_column.initializer())
+      self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+      self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+      self.assertEqual(42., embedding_column.max_norm)
+      self.assertFalse(embedding_column.trainable)
+      self.assertEqual('aaa_embedding', embedding_column.name)
+      self.assertEqual(
+          (embedding_dimension,), embedding_column._variable_shape)
+      self.assertEqual({
+          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+      }, embedding_column._parse_example_config)
+
+  def test_invalid_initializer(self):
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+      fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
+
+  def test_get_dense_tensor(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+  def test_get_dense_tensor_3d(self):
+    # Inputs.
+    vocabulary_size = 4
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [[2], []]
+        # example 1, ids [[], [0, 1]]
+        # example 2, ids [[], []]
+        # example 3, ids [[1], [2]]
+        indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)),
+        values=(2, 0, 1, 1, 2),
+        dense_shape=(4, 2, 5))
+
+    # Embedding variable.
+    embedding_dimension = 3
+    embedding_values = (
+        (1., 2., 4.),   # id 0
+        (3., 5., 1.),   # id 1
+        (7., 11., 2.),  # id 2
+        (2., 7., 12.)   # id 3
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]]
+        ((7., 11., 2.), (0., 0., 0.)),
+        # example 1, ids [[], [0, 1]], embedding
+        # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]]
+        ((0., 0., 0.), (2., 3.5, 2.5)),
+        # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]]
+        ((0., 0., 0.), (0., 0., 0.)),
+        # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]]
+        ((3., 5., 1.), (7., 11., 2.)),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+  def test_get_dense_tensor_weight_collections(self):
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_column = fc.embedding_column(categorical_column, dimension=2)
+
+    # Provide sparse input and get dense result.
+    embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_input
+    }), weight_collections=('my_vars',))
+
+    # Assert expected embedding variable in the given collection.
+    self.assertItemsEqual(
+        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+    my_vars = ops.get_collection('my_vars')
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in my_vars]))
+
+  def test_get_dense_tensor_placeholder_inputs(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        initializer=_initializer)
+
+    # Provide sparse input and get dense result.
+    input_indices = array_ops.placeholder(dtype=dtypes.int64)
+    input_values = array_ops.placeholder(dtype=dtypes.int64)
+    input_shape = array_ops.placeholder(dtype=dtypes.int64)
+    embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({
+        'aaa': sparse_tensor.SparseTensorValue(
+            indices=input_indices,
+            values=input_values,
+            dense_shape=input_shape)
+    }))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+    with _initialized_session():
+      self.assertAllEqual(embedding_values, global_vars[0].eval())
+      self.assertAllEqual(expected_lookups, embedding_lookup.eval(
+          feed_dict={
+              input_indices: sparse_input.indices,
+              input_values: sparse_input.values,
+              input_shape: sparse_input.dense_shape,
+          }))
+
+  def test_get_dense_tensor_restore_from_ckpt(self):
+    # Inputs.
+    vocabulary_size = 3
+    sparse_input = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 5))
+
+    # Embedding variable. The checkpoint file contains `embedding_values`.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+    ckpt_path = test.test_src_dir_path(
+        'python/feature_column/testdata/embedding.ckpt')
+    ckpt_tensor = 'my_embedding'
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups = (
+        # example 0, ids [2], embedding = [7, 11]
+        (7., 11.),
+        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+        (2., 3.5),
+        # example 2, ids [], embedding = [0, 0]
+        (0., 0.),
+        # example 3, ids [1], embedding = [3, 5]
+        (3., 5.),
+    )
+
+    # Build columns.
+    categorical_column = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    embedding_column = fc.embedding_column(
+        categorical_column, dimension=embedding_dimension,
+        ckpt_to_load_from=ckpt_path,
+        tensor_name_in_ckpt=ckpt_tensor)
+
+    # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column._get_dense_tensor(fc._LazyBuilder({ + 'aaa': sparse_input + })) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + def test_make_linear_model(self): + # Inputs. + batch_size = 4 + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(batch_size, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + categorical_column.name: sparse_input + }, (embedding_column,)) + expected_var_names = ( + 'make_linear_model/bias_weights:0', + 'make_linear_model/aaa_embedding/weights:0', + 'make_linear_model/aaa_embedding/embedding_weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v for v in ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['make_linear_model/bias_weights:0'] + embedding_weights = trainable_vars[ + 'make_linear_model/aaa_embedding/embedding_weights:0'] + linear_weights = trainable_vars[ + 'make_linear_model/aaa_embedding/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # example 2, ids [], embedding[2] = [0, 0] + # example 3, ids [1], embedding[3] = [3, 5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] + self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) + + def test_make_input_layer(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. 
+ embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + input_layer = fc.make_input_layer({ + 'aaa': sparse_input + }, (embedding_column,)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('make_input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + ('make_input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in trainable_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, trainable_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + def test_make_input_layer_not_trainable(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer, trainable=False) + + # Provide sparse input and get dense result. + input_layer = fc.make_input_layer({ + 'aaa': sparse_input + }, (embedding_column,)) + + # Assert expected embedding variable and lookups. 
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('make_input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/feature_column/testdata/embedding.ckpt.data-00000-of-00001 b/tensorflow/python/feature_column/testdata/embedding.ckpt.data-00000-of-00001 new file mode 100644 index 00000000000..5cc36d86d60 Binary files /dev/null and b/tensorflow/python/feature_column/testdata/embedding.ckpt.data-00000-of-00001 differ diff --git a/tensorflow/python/feature_column/testdata/embedding.ckpt.index b/tensorflow/python/feature_column/testdata/embedding.ckpt.index new file mode 100644 index 00000000000..c1f35a8fcff Binary files /dev/null and b/tensorflow/python/feature_column/testdata/embedding.ckpt.index differ diff --git a/tensorflow/python/feature_column/testdata/embedding.ckpt.meta b/tensorflow/python/feature_column/testdata/embedding.ckpt.meta new file mode 100644 index 00000000000..65bc3f2becb Binary files /dev/null and b/tensorflow/python/feature_column/testdata/embedding.ckpt.meta differ
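Note for reviewers (not part of the patch): a minimal end-to-end sketch of the API this change adds, mirroring the `embedding_column` docstring example and `EmbeddingColumnTest` above. The `tf.Session` boilerplate and the private-module import path are assumptions that follow the test file's conventions.

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

# Two examples: example 0 has ids [2], example 1 has ids [0, 1].
features = {
    'video_id': tf.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]],
        values=[2, 0, 1],
        dense_shape=[2, 2]),
}
video_id = fc.categorical_column_with_identity(key='video_id', num_buckets=3)
# Each example's ids are looked up and mean-combined into one dense row.
embedded_video_id = fc.embedding_column(video_id, dimension=2)
dense_tensor = fc.make_input_layer(features, [embedded_video_id])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(dense_tensor))  # Shape (2, 2): one embedding per example.
```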
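The supported combiners differ only in the per-example normalization applied after the looked-up rows are summed (with unit weights; see `tf.embedding_lookup_sparse` for the weighted case). A small sketch of the arithmetic for ids [0, 1] with the embedding values used throughout the tests:

```python
import math

rows = [[1., 2.], [3., 5.]]  # Embeddings looked up for ids 0 and 1.
summed = [sum(col) for col in zip(*rows)]           # 'sum'   -> [4.0, 7.0]
mean = [v / len(rows) for v in summed]              # 'mean'  -> [2.0, 3.5]
sqrtn = [v / math.sqrt(len(rows)) for v in summed]  # 'sqrtn' -> [2.83, 4.95]
```

The 'mean' result [2.0, 3.5] is exactly the `expected_lookups` row asserted for example 1 in the tests above.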
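The `ckpt_to_load_from`/`tensor_name_in_ckpt` path is exercised against the binary `testdata/embedding.ckpt` fixture added by this change. A hedged sketch of how such a fixture could be generated; the variable name 'my_embedding' and the values match what `test_get_dense_tensor_restore_from_ckpt` expects, but the script itself is illustrative and not part of the patch:

```python
import tensorflow as tf

# Save a 3x2 embedding table under the name the test passes as
# `tensor_name_in_ckpt`.
embedding = tf.get_variable(
    name='my_embedding',
    initializer=[[1., 2.], [3., 5.], [7., 11.]])
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.train.Saver().save(sess, 'testdata/embedding.ckpt')
```

Saving with `tf.train.Saver` produces the three files (`.data-00000-of-00001`, `.index`, `.meta`) listed in the BUILD `vocabulary_testdata` filegroup.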