Adds a method to lookup embedding results, accounting for invalid IDs and empty features.

Change: 122271376
This commit is contained in:
A. Unique TensorFlower 2016-05-13 09:25:32 -08:00 committed by TensorFlower Gardener
parent 7057cc8a31
commit f24b02735f
5 changed files with 281 additions and 0 deletions

View File

@ -7,6 +7,8 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "framework_py",
srcs = [
@ -14,6 +16,7 @@ py_library(
"python/framework/__init__.py",
"python/framework/tensor_util.py",
"python/ops/__init__.py",
"python/ops/embedding_ops.py",
"python/ops/ops.py",
"python/ops/variables.py",
],
@ -56,6 +59,17 @@ py_test(
],
)
cuda_py_test(
name = "embedding_ops_test",
size = "small",
srcs = ["python/ops/embedding_ops_test.py"],
additional_deps = [
":framework_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -27,6 +27,7 @@
@@is_strictly_increasing
@@local_variable
@@reduce_sum_n
@@safe_embedding_lookup_sparse
@@with_shape
@@with_same_shape

View File

@ -20,5 +20,6 @@ from __future__ import print_function
# TODO(ptucker): Add these to tf.contrib.variables?
# pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.ops.embedding_ops import *
from tensorflow.contrib.framework.python.ops.ops import *
from tensorflow.contrib.framework.python.ops.variables import *

View File

@ -0,0 +1,118 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops as tf_embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
__all__ = ["safe_embedding_lookup_sparse",]
def safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights=None, combiner="mean",
default_id=None, name=None, partition_strategy="div"):
"""Lookup embedding results, accounting for invalid IDs and empty features.
The partitioned embedding in `embedding_weights` must all be the same shape
except for the first dimension. The first dimension is allowed to vary as the
vocabulary size is not necessarily a multiple of `P`.
Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
with non-positive weight. For an entry with no features, the embedding vector
for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
Args:
embedding_weights: A list of `P` float tensors or values representing
partitioned embedding tensors.
sparse_ids: `SparseTensor` of shape `[batch_size, ?]` containing the ids.
sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
float weights corresponding to `sparse_ids`, or `None` if all weights
are be assumed to be 1.0.
combiner: A string specifying how to combine embedding results for each
entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
the default.
default_id: The id to use for an entry with no features.
name: A name for this operation (optional).
partition_strategy: A string specifying the partitioning strategy.
Currently `"div"` and `"mod"` are supported. Default is `"div"`.
Returns:
Dense tensor of shape `[batch_size, embed_dim]`.
Raises:
ValueError: if `embedding_weights` is empty.
"""
if embedding_weights is None or len(embedding_weights) < 1:
raise ValueError("Missing embedding_weights %s." % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights else None
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights]
contrib_tensor_util.assert_same_float_dtype(
embedding_weights + [sparse_weights])
with ops.op_scope(
embedding_weights + [sparse_ids, sparse_weights], name,
"embedding_lookup") as scope:
# Prune invalid ids and weights.
sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
# Fill in dummy values for empty features, if necessary.
sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
sparse_ids, default_id or 0)
if sparse_weights:
sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(
sparse_weights, 1.0)
result = tf_embedding_ops.embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights, combiner=combiner,
partition_strategy=partition_strategy,
name=None if default_id is None else scope)
if default_id is None:
# Broadcast is_row_empty to the same shape as embedding_lookup_result,
# for use in Select.
is_row_empty = array_ops.tile(
array_ops.reshape(is_row_empty, [-1, 1]),
array_ops.pack([1, array_ops.shape(result)[1]]))
result = math_ops.select(
is_row_empty, array_ops.zeros_like(result), result, name=scope)
return result
def _prune_invalid_ids(sparse_ids, sparse_weights=None,
filter_invalid_weights=True):
"""Prune invalid IDs (< 0) from the input ids and weights."""
is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
if sparse_weights and filter_invalid_weights:
is_id_valid = math_ops.logical_and(
is_id_valid, math_ops.greater(sparse_weights.values, 0))
sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
if sparse_weights:
sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
return sparse_ids, sparse_weights

View File

@ -0,0 +1,147 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""embedding_ops tests."""
# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import math
import numpy as np
import tensorflow as tf
class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
def _random_ids_and_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
assert vocab_size > 0
assert embed_dim > 0
assert num_shards > 0
assert num_shards <= vocab_size
embedding_weights = tf.create_partitioned_variables(
shape=[vocab_size, embed_dim],
slicing=[num_shards, 1],
initializer=tf.truncated_normal_initializer(
mean=0.0,
stddev=1.0 / math.sqrt(vocab_size),
dtype=tf.float32))
for w in embedding_weights:
w.initializer.run()
embedding_weights = [w.eval() for w in embedding_weights]
# Each row demonstrates a test case:
# Row 0: multiple valid ids, 1 invalid id, weighted mean
# Row 1: all ids are invalid (leaving no valid ids after pruning)
# Row 2: no ids to begin with
# Row 3: single id
# Row 4: all ids have <=0 weight
indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]]
ids = [0, 1, -1, -1, 2, 0, 1]
weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
shape = [5, 4]
sparse_ids = tf.SparseTensor(
tf.constant(indices, tf.int64),
tf.constant(ids, tf.int64),
tf.constant(shape, tf.int64))
sparse_weights = tf.SparseTensor(
tf.constant(indices, tf.int64),
tf.constant(weights, tf.float32),
tf.constant(shape, tf.int64))
return embedding_weights, sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
with self.test_session():
embedding_weights, sparse_ids, sparse_weights = (
self._random_ids_and_weights())
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
self.assertAllClose(embedding_lookup_result, [
(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
[0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4
])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
with self.test_session():
embedding_weights, sparse_ids, sparse_weights = (
self._random_ids_and_weights())
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids,
sparse_weights, default_id=3).eval())
self.assertAllClose(embedding_lookup_result, [
(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
embedding_weights[0][3], embedding_weights[0][3],
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
with self.test_session():
embedding_weights, sparse_ids, _ = self._random_ids_and_weights()
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None).eval())
self.assertAllClose(embedding_lookup_result, [
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
[0] * 4, embedding_weights[0][2],
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0
])
def test_safe_embedding_lookup_sparse_partitioned(self):
with self.test_session():
embedding_weights, sparse_ids, _ = self._random_ids_and_weights(
vocab_size=4, num_shards=3)
embedding_lookup_result = (
tf.contrib.framework.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result, [
(embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4,
embedding_weights[2],
(embedding_weights[0] + embedding_weights[1]) / 2.0
])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
with self.test_session():
embedding_weights, sparse_ids, sparse_weights = (
self._random_ids_and_weights(vocab_size=4, num_shards=3))
embedding_weights[1] = embedding_weights[1].astype(np.float64)
self.assertRaises(ValueError,
tf.contrib.framework.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
tf.constant(w, dtype=tf.float64) for w in embedding_weights]
self.assertRaises(
ValueError, tf.contrib.framework.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
if __name__ == "__main__":
tf.test.main()