From f24b02735feca015dbeb75f4c9b3eba16bdb134e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 13 May 2016 09:25:32 -0800 Subject: [PATCH] Adds a method to lookup embedding results, accounting for invalid IDs and empty features. Change: 122271376 --- tensorflow/contrib/framework/BUILD | 14 ++ tensorflow/contrib/framework/__init__.py | 1 + .../contrib/framework/python/ops/__init__.py | 1 + .../framework/python/ops/embedding_ops.py | 118 ++++++++++++++ .../python/ops/embedding_ops_test.py | 147 ++++++++++++++++++ 5 files changed, 281 insertions(+) create mode 100644 tensorflow/contrib/framework/python/ops/embedding_ops.py create mode 100644 tensorflow/contrib/framework/python/ops/embedding_ops_test.py diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index f1b9df6cbee..c8cbbc40bbd 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -7,6 +7,8 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + py_library( name = "framework_py", srcs = [ @@ -14,6 +16,7 @@ py_library( "python/framework/__init__.py", "python/framework/tensor_util.py", "python/ops/__init__.py", + "python/ops/embedding_ops.py", "python/ops/ops.py", "python/ops/variables.py", ], @@ -56,6 +59,17 @@ py_test( ], ) +cuda_py_test( + name = "embedding_ops_test", + size = "small", + srcs = ["python/ops/embedding_ops_test.py"], + additional_deps = [ + ":framework_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 7b00a6c24e0..14f5747a4d3 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -27,6 +27,7 @@ @@is_strictly_increasing @@local_variable @@reduce_sum_n +@@safe_embedding_lookup_sparse @@with_shape @@with_same_shape diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index 62aa8b1b187..6c3137cbf77 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -20,5 +20,6 @@ from __future__ import print_function # TODO(ptucker): Add these to tf.contrib.variables? # pylint: disable=wildcard-import +from tensorflow.contrib.framework.python.ops.embedding_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.variables import * diff --git a/tensorflow/contrib/framework/python/ops/embedding_ops.py b/tensorflow/contrib/framework/python/ops/embedding_ops.py new file mode 100644 index 00000000000..9e8c94fcddb --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/embedding_ops.py @@ -0,0 +1,118 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Embedding functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops as tf_embedding_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops + + +__all__ = ["safe_embedding_lookup_sparse",] + + +def safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights=None, combiner="mean", + default_id=None, name=None, partition_strategy="div"): + """Lookup embedding results, accounting for invalid IDs and empty features. + + The partitioned embedding in `embedding_weights` must all be the same shape + except for the first dimension. The first dimension is allowed to vary as the + vocabulary size is not necessarily a multiple of `P`. + + Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs + with non-positive weight. For an entry with no features, the embedding vector + for `default_id` is returned, or the 0-vector if `default_id` is not supplied. + + Args: + embedding_weights: A list of `P` float tensors or values representing + partitioned embedding tensors. + sparse_ids: `SparseTensor` of shape `[batch_size, ?]` containing the ids. + sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing + float weights corresponding to `sparse_ids`, or `None` if all weights + are be assumed to be 1.0. + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" + the default. + default_id: The id to use for an entry with no features. + name: A name for this operation (optional). + partition_strategy: A string specifying the partitioning strategy. + Currently `"div"` and `"mod"` are supported. Default is `"div"`. + + + Returns: + Dense tensor of shape `[batch_size, embed_dim]`. + + Raises: + ValueError: if `embedding_weights` is empty. + """ + if embedding_weights is None or len(embedding_weights) < 1: + raise ValueError("Missing embedding_weights %s." % embedding_weights) + + dtype = sparse_weights.dtype if sparse_weights else None + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights] + + contrib_tensor_util.assert_same_float_dtype( + embedding_weights + [sparse_weights]) + + with ops.op_scope( + embedding_weights + [sparse_ids, sparse_weights], name, + "embedding_lookup") as scope: + # Prune invalid ids and weights. + sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + + # Fill in dummy values for empty features, if necessary. + sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows( + sparse_ids, default_id or 0) + if sparse_weights: + sparse_weights, _ = sparse_ops.sparse_fill_empty_rows( + sparse_weights, 1.0) + + result = tf_embedding_ops.embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, combiner=combiner, + partition_strategy=partition_strategy, + name=None if default_id is None else scope) + + if default_id is None: + # Broadcast is_row_empty to the same shape as embedding_lookup_result, + # for use in Select. + is_row_empty = array_ops.tile( + array_ops.reshape(is_row_empty, [-1, 1]), + array_ops.pack([1, array_ops.shape(result)[1]])) + + result = math_ops.select( + is_row_empty, array_ops.zeros_like(result), result, name=scope) + + return result + + +def _prune_invalid_ids(sparse_ids, sparse_weights=None, + filter_invalid_weights=True): + """Prune invalid IDs (< 0) from the input ids and weights.""" + is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) + if sparse_weights and filter_invalid_weights: + is_id_valid = math_ops.logical_and( + is_id_valid, math_ops.greater(sparse_weights.values, 0)) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) + if sparse_weights: + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) + return sparse_ids, sparse_weights diff --git a/tensorflow/contrib/framework/python/ops/embedding_ops_test.py b/tensorflow/contrib/framework/python/ops/embedding_ops_test.py new file mode 100644 index 00000000000..8fff73cb552 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/embedding_ops_test.py @@ -0,0 +1,147 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""embedding_ops tests.""" + +# pylint: disable=unused-import +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import math + +import numpy as np +import tensorflow as tf + + +class SafeEmbeddingLookupSparseTest(tf.test.TestCase): + + def _random_ids_and_weights(self, vocab_size=4, embed_dim=4, num_shards=1): + assert vocab_size > 0 + assert embed_dim > 0 + assert num_shards > 0 + assert num_shards <= vocab_size + + embedding_weights = tf.create_partitioned_variables( + shape=[vocab_size, embed_dim], + slicing=[num_shards, 1], + initializer=tf.truncated_normal_initializer( + mean=0.0, + stddev=1.0 / math.sqrt(vocab_size), + dtype=tf.float32)) + for w in embedding_weights: + w.initializer.run() + embedding_weights = [w.eval() for w in embedding_weights] + + # Each row demonstrates a test case: + # Row 0: multiple valid ids, 1 invalid id, weighted mean + # Row 1: all ids are invalid (leaving no valid ids after pruning) + # Row 2: no ids to begin with + # Row 3: single id + # Row 4: all ids have <=0 weight + indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [5, 4] + + sparse_ids = tf.SparseTensor( + tf.constant(indices, tf.int64), + tf.constant(ids, tf.int64), + tf.constant(shape, tf.int64)) + + sparse_weights = tf.SparseTensor( + tf.constant(indices, tf.int64), + tf.constant(weights, tf.float32), + tf.constant(shape, tf.int64)) + + return embedding_weights, sparse_ids, sparse_weights + + def test_safe_embedding_lookup_sparse_return_zero_vector(self): + with self.test_session(): + embedding_weights, sparse_ids, sparse_weights = ( + self._random_ids_and_weights()) + + embedding_lookup_result = ( + tf.contrib.framework.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights).eval()) + + self.assertAllClose(embedding_lookup_result, [ + (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, + [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4 + ]) + + def test_safe_embedding_lookup_sparse_return_special_vector(self): + with self.test_session(): + embedding_weights, sparse_ids, sparse_weights = ( + self._random_ids_and_weights()) + + embedding_lookup_result = ( + tf.contrib.framework.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, + sparse_weights, default_id=3).eval()) + + self.assertAllClose(embedding_lookup_result, [ + (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, + embedding_weights[0][3], embedding_weights[0][3], + embedding_weights[0][2], embedding_weights[0][3]]) + + def test_safe_embedding_lookup_sparse_no_weights(self): + with self.test_session(): + embedding_weights, sparse_ids, _ = self._random_ids_and_weights() + + embedding_lookup_result = ( + tf.contrib.framework.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + self.assertAllClose(embedding_lookup_result, [ + (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, + [0] * 4, embedding_weights[0][2], + (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0 + ]) + + def test_safe_embedding_lookup_sparse_partitioned(self): + with self.test_session(): + embedding_weights, sparse_ids, _ = self._random_ids_and_weights( + vocab_size=4, num_shards=3) + + embedding_lookup_result = ( + tf.contrib.framework.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + embedding_weights = list(itertools.chain(*embedding_weights)) + self.assertAllClose(embedding_lookup_result, [ + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4, + embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0 + ]) + + def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self): + with self.test_session(): + embedding_weights, sparse_ids, sparse_weights = ( + self._random_ids_and_weights(vocab_size=4, num_shards=3)) + + embedding_weights[1] = embedding_weights[1].astype(np.float64) + self.assertRaises(ValueError, + tf.contrib.framework.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids) + embedding_weights = [ + tf.constant(w, dtype=tf.float64) for w in embedding_weights] + self.assertRaises( + ValueError, tf.contrib.framework.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids, sparse_weights) + + +if __name__ == "__main__": + tf.test.main()