Adds a method to look up embedding results, accounting for invalid IDs and empty features.
Change: 122271376
This commit is contained in:
parent
7057cc8a31
commit
f24b02735f
@ -7,6 +7,8 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "framework_py",
|
||||
srcs = [
|
||||
@ -14,6 +16,7 @@ py_library(
|
||||
"python/framework/__init__.py",
|
||||
"python/framework/tensor_util.py",
|
||||
"python/ops/__init__.py",
|
||||
"python/ops/embedding_ops.py",
|
||||
"python/ops/ops.py",
|
||||
"python/ops/variables.py",
|
||||
],
|
||||
@ -56,6 +59,17 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "embedding_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/ops/embedding_ops_test.py"],
|
||||
additional_deps = [
|
||||
":framework_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -27,6 +27,7 @@
|
||||
@@is_strictly_increasing
|
||||
@@local_variable
|
||||
@@reduce_sum_n
|
||||
@@safe_embedding_lookup_sparse
|
||||
@@with_shape
|
||||
@@with_same_shape
|
||||
|
||||
|
@ -20,5 +20,6 @@ from __future__ import print_function
|
||||
|
||||
# TODO(ptucker): Add these to tf.contrib.variables?
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.framework.python.ops.embedding_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.ops import *
|
||||
from tensorflow.contrib.framework.python.ops.variables import *
|
||||
|
118
tensorflow/contrib/framework/python/ops/embedding_ops.py
Normal file
118
tensorflow/contrib/framework/python/ops/embedding_ops.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Embedding functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops as tf_embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
__all__ = ["safe_embedding_lookup_sparse",]
|
||||
|
||||
|
||||
def safe_embedding_lookup_sparse(
    embedding_weights, sparse_ids, sparse_weights=None, combiner="mean",
    default_id=None, name=None, partition_strategy="div"):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  Args:
    embedding_weights: A list of `P` float tensors or values representing
        partitioned embedding tensors.
    sparse_ids: `SparseTensor` of shape `[batch_size, ?]` containing the ids.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
        float weights corresponding to `sparse_ids`, or `None` if all weights
        are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
        Currently `"div"` and `"mod"` are supported. Default is `"div"`.

  Returns:
    Dense tensor of shape `[batch_size, embed_dim]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None or len(embedding_weights) < 1:
    # Wrap the argument in a 1-tuple so that a tuple-valued
    # `embedding_weights` (e.g. `()`) is formatted as a single value instead
    # of being unpacked as the format arguments, which would raise TypeError
    # rather than the documented ValueError.
    raise ValueError("Missing embedding_weights %s." % (embedding_weights,))

  # Test the optional argument against None explicitly instead of relying on
  # the truthiness of the SparseTensor object.
  has_weights = sparse_weights is not None

  dtype = sparse_weights.dtype if has_weights else None
  embedding_weights = [
      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights]

  contrib_tensor_util.assert_same_float_dtype(
      embedding_weights + [sparse_weights])

  with ops.op_scope(
      embedding_weights + [sparse_ids, sparse_weights], name,
      "embedding_lookup") as scope:
    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary. When no
    # `default_id` was supplied, id 0 is used as a placeholder and the
    # corresponding rows are zeroed out below.
    fill_id = 0 if default_id is None else default_id
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, fill_id)
    if has_weights:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(
          sparse_weights, 1.0)

    # If a select op follows (no `default_id`), that op takes the scope name,
    # so the lookup itself is left unnamed in that case.
    result = tf_embedding_ops.embedding_lookup_sparse(
        embedding_weights, sparse_ids, sparse_weights, combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.pack([1, array_ops.shape(result)[1]]))

      result = math_ops.select(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    return result
|
||||
|
||||
|
||||
def _prune_invalid_ids(sparse_ids, sparse_weights=None,
                       filter_invalid_weights=True):
  """Prune invalid IDs (< 0) from the input ids and weights.

  Args:
    sparse_ids: Sparse tensor of ids; entries whose value is negative are
        dropped.
    sparse_weights: Optional sparse tensor of weights, pruned with the same
        mask as `sparse_ids`. `None` means no weights were supplied.
    filter_invalid_weights: If true and `sparse_weights` is given, entries
        whose weight is not strictly positive are dropped as well.

  Returns:
    A `(sparse_ids, sparse_weights)` pair with the invalid entries removed.
    `sparse_weights` is returned unchanged (`None`) when it was not supplied.
  """
  # Test the optional argument against None explicitly instead of relying on
  # the truthiness of the SparseTensor object.
  has_weights = sparse_weights is not None

  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights and filter_invalid_weights:
    is_id_valid = math_ops.logical_and(
        is_id_valid, math_ops.greater(sparse_weights.values, 0))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if has_weights:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights
|
147
tensorflow/contrib/framework/python/ops/embedding_ops_test.py
Normal file
147
tensorflow/contrib/framework/python/ops/embedding_ops_test.py
Normal file
@ -0,0 +1,147 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""embedding_ops tests."""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
  # Exercises tf.contrib.framework.safe_embedding_lookup_sparse against a
  # small fixed fixture of sparse ids and weights.

  def _random_ids_and_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    # Builds the shared fixture: randomly initialized embedding shards
    # (evaluated to concrete values) plus fixed sparse ids and weights.
    # Must be called inside an active test session, since the variables are
    # initialized and evaluated here.
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    initializer = tf.truncated_normal_initializer(
        mean=0.0,
        stddev=1.0 / math.sqrt(vocab_size),
        dtype=tf.float32)
    shard_variables = tf.create_partitioned_variables(
        shape=[vocab_size, embed_dim],
        slicing=[num_shards, 1],
        initializer=initializer)
    for variable in shard_variables:
      variable.initializer.run()
    embedding_weights = [variable.eval() for variable in shard_variables]

    # One fixture row per scenario:
    #   Row 0: multiple valid ids, 1 invalid id, weighted mean
    #   Row 1: all ids are invalid (leaving no valid ids after pruning)
    #   Row 2: no ids to begin with
    #   Row 3: single id
    #   Row 4: all ids have <=0 weight
    indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]]
    ids = [0, 1, -1, -1, 2, 0, 1]
    weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
    shape = [5, 4]

    id_indices = tf.constant(indices, tf.int64)
    id_values = tf.constant(ids, tf.int64)
    id_shape = tf.constant(shape, tf.int64)
    sparse_ids = tf.SparseTensor(id_indices, id_values, id_shape)

    weight_indices = tf.constant(indices, tf.int64)
    weight_values = tf.constant(weights, tf.float32)
    weight_shape = tf.constant(shape, tf.int64)
    sparse_weights = tf.SparseTensor(
        weight_indices, weight_values, weight_shape)

    return embedding_weights, sparse_ids, sparse_weights

  def test_safe_embedding_lookup_sparse_return_zero_vector(self):
    # Without a default_id, rows left empty after pruning yield zero vectors.
    with self.test_session():
      embedding_weights, sparse_ids, sparse_weights = (
          self._random_ids_and_weights())

      lookup = tf.contrib.framework.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights)
      embedding_lookup_result = lookup.eval()

      table = embedding_weights[0]
      zero_row = [0] * 4
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      expected = [weighted_mean, zero_row, zero_row, table[2], zero_row]
      self.assertAllClose(embedding_lookup_result, expected)

  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    # With default_id=3, rows left empty after pruning yield embedding row 3.
    with self.test_session():
      embedding_weights, sparse_ids, sparse_weights = (
          self._random_ids_and_weights())

      lookup = tf.contrib.framework.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights, default_id=3)
      embedding_lookup_result = lookup.eval()

      table = embedding_weights[0]
      default_row = table[3]
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      expected = [
          weighted_mean, default_row, default_row, table[2], default_row]
      self.assertAllClose(embedding_lookup_result, expected)

  def test_safe_embedding_lookup_sparse_no_weights(self):
    # With no weights supplied, only negative ids are pruned, so row 4 keeps
    # both of its ids and is combined as a plain mean.
    with self.test_session():
      embedding_weights, sparse_ids, _ = self._random_ids_and_weights()

      lookup = tf.contrib.framework.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      embedding_lookup_result = lookup.eval()

      table = embedding_weights[0]
      zero_row = [0] * 4
      mean_of_first_two = (table[0] + table[1]) / 2.0
      expected = [
          mean_of_first_two, zero_row, zero_row, table[2], mean_of_first_two]
      self.assertAllClose(embedding_lookup_result, expected)

  def test_safe_embedding_lookup_sparse_partitioned(self):
    # Same expectations as the unweighted case, with the table in 3 shards.
    with self.test_session():
      embedding_weights, sparse_ids, _ = self._random_ids_and_weights(
          vocab_size=4, num_shards=3)

      lookup = tf.contrib.framework.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      embedding_lookup_result = lookup.eval()

      # Flatten the shards into one list of embedding rows.
      rows = list(itertools.chain.from_iterable(embedding_weights))
      zero_row = [0] * 4
      mean_of_first_two = (rows[0] + rows[1]) / 2.0
      expected = [
          mean_of_first_two, zero_row, zero_row, rows[2], mean_of_first_two]
      self.assertAllClose(embedding_lookup_result, expected)

  def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
    # Mismatched float dtypes must be rejected with ValueError.
    with self.test_session():
      embedding_weights, sparse_ids, sparse_weights = (
          self._random_ids_and_weights(vocab_size=4, num_shards=3))

      # One shard in float64 while the others stay float32.
      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      with self.assertRaises(ValueError):
        tf.contrib.framework.safe_embedding_lookup_sparse(
            embedding_weights, sparse_ids)

      # All shards in float64 while the sparse weights are float32.
      embedding_weights = [
          tf.constant(shard, dtype=tf.float64) for shard in embedding_weights]
      with self.assertRaises(ValueError):
        tf.contrib.framework.safe_embedding_lookup_sparse(
            embedding_weights, sparse_ids, sparse_weights)
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
|
Loading…
Reference in New Issue
Block a user