Add a new hashed embedding column

Change: 126254010
Authored by A. Unique TensorFlower on 2016-06-29 16:58:19 -08:00; committed by TensorFlower Gardener
parent a2b9788ce4
commit 025e65591e
7 changed files with 443 additions and 34 deletions

View File

@@ -77,17 +77,6 @@ py_test(
],
)
cuda_py_test(
name = "embedding_ops_test",
size = "small",
srcs = ["python/ops/embedding_ops_test.py"],
additional_deps = [
":framework_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "sampling_ops_test",
size = "small",

View File

@@ -213,6 +213,19 @@ py_test(
],
)
py_test(
name = "embedding_ops_test",
size = "small",
srcs = ["python/layers/embedding_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":layers_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.layers.python.layers.embedding_ops import *
from tensorflow.contrib.layers.python.layers.feature_column import *
from tensorflow.contrib.layers.python.layers.feature_column_ops import *
from tensorflow.contrib.layers.python.layers.initializers import *

View File

@@ -0,0 +1,191 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
__all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup",
"hashed_embedding_lookup_sparse"]
# TODO(chapelle): move the safe_embedding_lookup_sparse code here (b/29826543)
safe_embedding_lookup_sparse = contrib_embedding_ops.safe_embedding_lookup_sparse # pylint: disable=line-too-long
def hashed_embedding_lookup(params, values, dimension, name=None):
"""Looks up embeddings using parameter hashing for each value in `values`.
The i-th embedding component of a value v in `values` is found by retrieving
the weight whose index is a fingerprint of the pair (v,i).
The concept is explored as "feature hashing" for model compression in this
paper: http://arxiv.org/pdf/1504.04788.pdf
Feature hashing has the pleasant effect of allowing us to compute an embedding
without needing a pre-determined vocabulary, relieving some amount of process
complexity. It also allows us to maintain embeddings for possibly
trillions of features with a fixed amount of memory.
Note that this is superior to out-of-vocabulary shared "hash buckets" in that
the embedding is extremely likely to be unique for each token as opposed to
being shared across probably-colliding tokens. The price is that we must
compute a hash once for each scalar in the token's embedding as opposed to
once per token.
If `params` is a list, it represents a partition of the embedding parameters.
Each tensor in the list should have the same length, except for the first ones
which may have an additional element. For instance, 10 parameters can be
partitioned into 4 tensors with lengths `[3, 3, 2, 2]`.
Args:
params: A `Tensor` or `list` of `Tensors`.
Each tensor must be of rank 1 with fully-defined shape.
values: `Tensor` of values to be embedded.
dimension: Embedding dimension
name: An optional name for this op.
Returns:
A tensor with shape [d0, ..., dn, dimension],
where shape(values) = [d0, ..., dn].
Raises:
ValueError: if dimension is not positive or the partition size is invalid.
"""
if not isinstance(params, list):
params = [params]
with ops.op_scope(params + [dimension, values], name,
"hashed_embedding_lookup"):
if dimension <= 0:
raise ValueError("Dimension should be >0 not %d" % dimension)
num_partitions = len(params)
partition_sizes = []
for p in range(num_partitions):
shape = params[p].get_shape()
shape.assert_has_rank(1)
shape.assert_is_fully_defined()
partition_sizes.append(shape[0].value)
num_params = sum(partition_sizes) # Total number of parameters.
# Assert the size of each partition.
for p in range(num_partitions):
expected_size = (num_params - p - 1) // num_partitions + 1
if partition_sizes[p] != expected_size:
raise ValueError("Tensor %d in params has size %d, expected %d." %
(p, partition_sizes[p], expected_size))
# Flatten the values
values_shape = array_ops.shape(values)
values = array_ops.reshape(values, [-1, 1])
# With two values v1 and v2 and 3 dimensions, we will cross
# [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
tensors_to_cross = [array_ops.tile(array_ops.expand_dims(
math_ops.range(0, dimension), 0), array_ops.shape(values)), values]
ids = sparse_feature_cross_op.sparse_feature_cross(
tensors_to_cross, hashed_output=True, num_buckets=num_params)
ids = sparse_ops.sparse_tensor_to_dense(ids)
# No need to validate the indices since we have checked the params
# dimensions and we know the largest id.
result = embedding_ops.embedding_lookup(
params, ids, partition_strategy="div", validate_indices=False)
return array_ops.reshape(result, array_ops.concat(
0, [values_shape, [dimension]]))
def hashed_embedding_lookup_sparse(params,
sparse_values,
dimension,
combiner="mean",
default_value=None,
name=None):
"""Looks up embeddings of a sparse feature using parameter hashing.
See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing.
Args:
params: A `Tensor` or `list` of `Tensors`.
Each tensor must be of rank 1 with fully-defined shape.
sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
Some rows may be empty.
dimension: Embedding dimension
combiner: A string specifying how to combine embedding results for each
entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
the default.
default_value: The value to use for an entry with no features.
name: An optional name for this op.
Returns:
Dense tensor with shape [N, dimension], where N is the number of rows in
sparse_values.
Raises:
TypeError: If sparse_values is not a SparseTensor.
ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
"""
if not isinstance(params, list):
params = [params]
if not isinstance(sparse_values, ops.SparseTensor):
raise TypeError("sparse_values must be SparseTensor")
with ops.op_scope(params + [sparse_values], name,
"hashed_sparse_embedding_lookup") as scope:
# Fill in the empty rows.
if default_value is None:
# Random default values to reduce the risk of collision.
if sparse_values.dtype == dtypes.string:
default_value = "6ZxWzWOHxZ"
else:
default_value = 1288896567
sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
sparse_values, default_value)
segment_ids = sparse_values.indices[:, 0]
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
values = sparse_values.values
values, idx = array_ops.unique(values)
embeddings = hashed_embedding_lookup(params, values, dimension)
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
name=scope)
elif combiner == "mean":
embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
name=scope)
elif combiner == "sqrtn":
embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
name=scope)
else:
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
return embeddings

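For orientation, here is a minimal usage sketch of the two new ops defined in the file above. It is not part of the commit; it assumes the TF 0.9-era APIs already used in this change (tf.create_partitioned_variables, tf.initialize_all_variables, the SparseTensor shape= argument), and the feature values are illustrative.

import tensorflow as tf

with tf.Session() as sess:
  # 10 hashed parameters split into 4 shards of sizes [3, 3, 2, 2], matching
  # the partition rule ((num_params - p - 1) // num_partitions + 1) that
  # hashed_embedding_lookup checks.
  params = tf.create_partitioned_variables(
      shape=[10], slicing=[4],
      initializer=tf.truncated_normal_initializer(stddev=1.0))

  # Dense lookup: component i of the embedding of value v is the weight whose
  # index is a fingerprint of (v, i), so no vocabulary is required.
  dense = tf.contrib.layers.hashed_embedding_lookup(
      params, tf.constant(["foo", "bar"]), dimension=4)

  # Sparse lookup: row 1 is empty and is filled with a default value, then
  # the embeddings within each row are combined with "mean".
  sp = tf.SparseTensor(values=["foo", "bar", "foo"],
                       indices=[[0, 0], [0, 1], [2, 0]],
                       shape=[3, 2])
  combined = tf.contrib.layers.hashed_embedding_lookup_sparse(
      params, sp, dimension=4, combiner="mean")

  sess.run(tf.initialize_all_variables())
  print(sess.run(dense).shape)     # (2, 4)
  print(sess.run(combined).shape)  # (3, 4)
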
View File

@@ -100,7 +100,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
self.assertAllClose(embedding_lookup_result, [
@@ -114,7 +114,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights,
sparse_ids,
sparse_weights,
@@ -132,9 +132,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
self.assertAllClose(embedding_lookup_result, [
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@@ -148,9 +148,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_2d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result, [
@@ -166,13 +166,13 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
embedding_weights[1] = embedding_weights[1].astype(np.float64)
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
tf.constant(w, dtype=tf.float64) for w in embedding_weights
]
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
@@ -181,7 +181,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
self.assertAllClose(embedding_lookup_result, [
@@ -195,7 +195,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, sparse_weights = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
tf.contrib.layers.safe_embedding_lookup_sparse(
embedding_weights,
sparse_ids,
sparse_weights,
@@ -214,9 +214,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
self.assertAllClose(embedding_lookup_result, [
[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@@ -231,9 +231,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
sparse_ids, _ = self._ids_and_weights_3d()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result, [
@@ -251,15 +251,110 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
embedding_weights[1] = embedding_weights[1].astype(np.float64)
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
tf.constant(w, dtype=tf.float64) for w in embedding_weights
]
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
tf.contrib.layers.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
class HashedEmbeddingLookupTest(tf.test.TestCase):
def setUp(self):
tf.set_random_seed(1)
def _random_weights(self, size=50, num_shards=1):
assert size > 0
assert num_shards > 0
assert num_shards <= size
embedding_weights = tf.create_partitioned_variables(
shape=[size],
slicing=[num_shards],
initializer=tf.truncated_normal_initializer(mean=0.0,
stddev=1.0,
dtype=tf.float32))
for w in embedding_weights:
w.initializer.run()
return embedding_weights
def test_hashed_embedding_consistency(self):
with self.test_session():
embedding_weights = self._random_weights()
values = tf.constant(["foo", "foo"])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=10).eval()
self.assertAllEqual(embedding_lookup_result.shape, [2, 10])
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
def test_hashed_embedding_multiple_partition(self):
with self.test_session():
embedding_weights = self._random_weights(num_shards=7)
values = tf.constant([4, 4, 5])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=5).eval()
self.assertAllEqual(embedding_lookup_result.shape, [3, 5])
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
# Different embedding expected for different value.
embedding_diff = np.min((embedding_lookup_result[2] -
embedding_lookup_result[0]) ** 2)
self.assertGreater(embedding_diff, 0)
def test_hashed_embedding_coverage(self):
with self.test_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = tf.constant(["foo"])
# Large embedding dimension to cover the full range of weights.
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=100).eval()
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_hashed_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
values = tf.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]])
embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
embedding_weights, values, dimension=10).eval()
self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 10])
self.assertAllEqual(embedding_lookup_result[0][0],
embedding_lookup_result[1][2])
def test_hashed_embedding_lookup_sparse(self):
with self.test_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = tf.SparseTensor(values=["foo", "bar", "foo", "bar"],
indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
shape=[5, 2])
embedding_lookup_result = (
tf.contrib.layers.hashed_embedding_lookup_sparse(
embedding_weights, sparse_tensor, dimension=5, combiner="mean")
.eval())
self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
# Same non-zero embedding for the empty rows filled with a default value.
self.assertAllEqual(embedding_lookup_result[2],
embedding_lookup_result[4])
embedding_norm = np.sum(embedding_lookup_result[2] ** 2)
self.assertGreater(embedding_norm, 0)
self.assertAllEqual(embedding_lookup_result[1],
0.5 * (embedding_lookup_result[0] +
embedding_lookup_result[3]))
if __name__ == "__main__":
tf.test.main()

View File

@@ -75,7 +75,7 @@ import abc
import collections
import math
from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
@@ -517,7 +517,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
# This is effectively the same format as str(self), except with our special
# treatment.
return "_EmbeddingColumn(%s)" % ", ".join(fields_values)
return "%s(%s)" % (type(self).__name__, ", ".join(fields_values))
def insert_transformed_feature(self, columns_to_tensors):
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
@@ -576,6 +576,105 @@ def embedding_column(sparse_id_column,
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
class _HashedEmbeddingColumn(collections.namedtuple(
"_HashedEmbeddingColumn", ["column_name", "size", "dimension", "combiner",
"initializer"]), _EmbeddingColumn):
"""See `hashed_embedding_column`."""
def __new__(cls,
column_name,
size,
dimension,
combiner="mean",
initializer=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
if initializer is None:
stddev = 0.1
# TODO(b/25671353): Better initial value?
initializer = init_ops.truncated_normal_initializer(mean=0.0,
stddev=stddev)
return super(_HashedEmbeddingColumn, cls).__new__(cls, column_name, size,
dimension, combiner,
initializer)
@property
def name(self):
return self.column_name + "_embedding"
@property
def config(self):
return {self.column_name: parsing_ops.VarLenFeature(dtypes.string)}
def insert_transformed_feature(self, columns_to_tensors):
columns_to_tensors[self] = columns_to_tensors[self.column_name]
def to_dnn_input_layer(self,
input_tensor,
weight_collections=None,
trainable=True):
# Same heuristic for the number of shards as _max_size_embedding_partitioner
max_shard_bytes = (64 << 20) - 1
shards = self.size * 4.0 / max_shard_bytes
shards = max(1, int(math.ceil(shards)))
embeddings = partitioned_variables.create_partitioned_variables(
shape=[self.size],
slicing=[shards],
initializer=self.initializer,
dtype=dtypes.float32,
collections=_add_variable_collection(weight_collections),
name=self.name + "_weights",
reuse=False,
trainable=trainable)
return embedding_ops.hashed_embedding_lookup_sparse(
embeddings, input_tensor, self.dimension, name=self.name + "_lookup")
def hashed_embedding_column(column_name,
size,
dimension,
combiner="mean",
initializer=None):
"""Creates an embedding column of a sparse feature using parameter hashing.
The i-th embedding component of a value v is found by retrieving an
embedding weight whose index is a fingerprint of the pair (v,i).
Args:
column_name: A string defining sparse column name.
size: An integer specifying the number of parameters in the embedding layer.
dimension: An integer specifying dimension of the embedding.
combiner: A string specifying how to reduce if there are multiple entries
in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
of these can be thought of as an example-level normalization on the column:
* "sum": do not normalize features in the column
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
For more information: `tf.embedding_lookup_sparse`.
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0 and standard deviation 0.1.
Returns:
A _HashedEmbeddingColumn.
Raises:
ValueError: if dimension or size is not a positive integer; or if combiner
is not supported.
"""
if (dimension < 1) or (size < 1):
raise ValueError("Dimension and size must be greater than 0.")
if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
initializer)
class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
"_RealValuedColumn",
["column_name", "dimension", "default_value", "dtype"])):
@@ -1237,7 +1336,7 @@ def _create_embedding_lookup(input_tensor, vocab_size, dimension,
reuse=False,
trainable=trainable)
return contrib_embedding_ops.safe_embedding_lookup_sparse(
return embedding_ops.safe_embedding_lookup_sparse(
embeddings,
input_tensor,
default_id=0,

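For completeness, a short sketch of how the new column is meant to be wired into input_from_feature_columns. It is not part of the diff; it mirrors testHashedEmbeddingColumn below, with illustrative feature values.

import math
import tensorflow as tf

# Embed the sparse "wire" feature into 3 dimensions using 1000 hashed
# parameters; no vocabulary or hash-bucket column is needed.
embedded_wire = tf.contrib.layers.hashed_embedding_column(
    "wire", size=1000, dimension=3)

# Shard-count heuristic used by to_dnn_input_layer above: roughly one shard
# per 64MB of float32 parameters, so a size-1000 column stays in one shard.
max_shard_bytes = (64 << 20) - 1
shards = max(1, int(math.ceil(1000 * 4.0 / max_shard_bytes)))  # -> 1

features = {"wire": tf.SparseTensor(values=["omar", "stringer"],
                                    indices=[[0, 0], [1, 0]],
                                    shape=[2, 1])}
output = tf.contrib.layers.input_from_feature_columns(features, [embedded_wire])

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(output).shape)  # (2, 3)
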
View File

@@ -282,6 +282,27 @@ class InputLayerTest(tf.test.TestCase):
tf.initialize_all_variables().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testHashedEmbeddingColumn(self):
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo", "omar"],
indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
shape=[3, 2])
features = {"wire": wire_tensor}
# Big enough hash space so that hopefully there is no collision
embedded_sparse = tf.contrib.layers.hashed_embedding_column("wire", 1000, 3)
output = tf.contrib.layers.input_from_feature_columns(
features, [embedded_sparse], weight_collections=["my_collection"])
weights = tf.get_collection("my_collection")
grad = tf.gradients(output, weights)
with self.test_session():
tf.initialize_all_variables().run()
gradient_values = []
# Collect the gradient from the different partitions (one in this test)
for p in range(len(grad)):
gradient_values.extend(grad[p].values.eval())
gradient_values.sort()
self.assertAllEqual(gradient_values, [0.5]*6 + [2]*3)
def testEmbeddingColumnWithInitializer(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],