From 025e65591ef201118f1490c28f0f5df89cbf5104 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 29 Jun 2016 16:58:19 -0800
Subject: [PATCH] Add a new hashed embedding column Change: 126254010

---
 tensorflow/contrib/framework/BUILD            |  11 -
 tensorflow/contrib/layers/BUILD               |  13 ++
 .../contrib/layers/python/layers/__init__.py  |   1 +
 .../layers/python/layers/embedding_ops.py     | 191 ++++++++++++++++++
 .../python/layers}/embedding_ops_test.py      | 135 +++++++++++--
 .../layers/python/layers/feature_column.py    | 105 +++++++++-
 .../python/layers/feature_column_ops_test.py  |  21 ++
 7 files changed, 443 insertions(+), 34 deletions(-)
 create mode 100644 tensorflow/contrib/layers/python/layers/embedding_ops.py
 rename tensorflow/contrib/{framework/python/ops => layers/python/layers}/embedding_ops_test.py (63%)

diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index b13de8e53f5..cb4749c88e7 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -77,17 +77,6 @@ py_test(
     ],
 )
 
-cuda_py_test(
-    name = "embedding_ops_test",
-    size = "small",
-    srcs = ["python/ops/embedding_ops_test.py"],
-    additional_deps = [
-        ":framework_py",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
 cuda_py_test(
     name = "sampling_ops_test",
     size = "small",
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index f4827c06465..bd05720b264 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -213,6 +213,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "embedding_ops_test",
+    size = "small",
+    srcs = ["python/layers/embedding_ops_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/layers/python/layers/__init__.py b/tensorflow/contrib/layers/python/layers/__init__.py
index 239aee7a3d2..f1f3b52a50c 100644
--- a/tensorflow/contrib/layers/python/layers/__init__.py
+++ b/tensorflow/contrib/layers/python/layers/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import
+from tensorflow.contrib.layers.python.layers.embedding_ops import *
 from tensorflow.contrib.layers.python.layers.feature_column import *
 from tensorflow.contrib.layers.python.layers.feature_column_ops import *
 from tensorflow.contrib.layers.python.layers.initializers import *
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
new file mode 100644
index 00000000000..4904c16a9cd
--- /dev/null
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -0,0 +1,191 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Embedding functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
+from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
+from tensorflow.python.framework import dtypes
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+
+__all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup",
+           "hashed_embedding_lookup_sparse"]
+
+
+# TODO(chapelle): move the safe_embedding_lookup_sparse code here (b/29826543)
+safe_embedding_lookup_sparse = contrib_embedding_ops.safe_embedding_lookup_sparse  # pylint: disable=line-too-long
+
+
+def hashed_embedding_lookup(params, values, dimension, name=None):
+  """Looks up embeddings using parameter hashing for each value in `values`.
+
+  The i-th embedding component of a value v in `values` is found by retrieving
+  the weight whose index is a fingerprint of the pair (v,i).
+  The concept is explored as "feature hashing" for model compression in this
+  paper: http://arxiv.org/pdf/1504.04788.pdf
+
+  Feature hashing has the pleasant effect of allowing us to compute an embedding
+  without needing a pre-determined vocabulary, relieving some amount of process
+  complexity. It also allows for us to maintain embeddings for possibly
+  trillions of features with a fixed amount of memory.
+
+  Note that this is superior to out-of-vocabulary shared "hash buckets" in that
+  the embedding is extremely likely to be unique for each token as opposed to
+  being shared across probably-colliding tokens. The price is that we must
+  compute a hash once for each scalar in the token's embedding as opposed to
+  once per token.
+
+  If `params` is a list, it represents a partition of the embedding parameters.
+  Each tensor in the list should have the same length, except for the first ones
+  which may have an additional element. For instance 10 parameters can be
+  partitioned in 4 tensors with length `[3, 3, 2, 2]`.
+
+  Args:
+    params: A `Tensor` or `list` of `Tensors`.
+      Each tensor must be of rank 1 with fully-defined shape.
+    values: `Tensor` of values to be embedded.
+    dimension: Embedding dimension
+    name: An optional name for this op.
+
+  Returns:
+    A tensor with shape [d0, ..., dn, dimension]
+      with shape(values) = [d0, ..., dn]
+
+  Raises:
+    ValueError: if dimension is not positive or the partition size is invalid.
+  """
+  if not isinstance(params, list):
+    params = [params]
+
+  with ops.op_scope(params + [dimension, values], name,
+                    "hashed_embedding_lookup"):
+    if dimension <= 0:
+      raise ValueError("Dimension should be >0 not %d" % dimension)
+
+    num_partitions = len(params)
+    partition_sizes = []
+    for p in range(num_partitions):
+      shape = params[p].get_shape()
+      shape.assert_has_rank(1)
+      shape.assert_is_fully_defined()
+      partition_sizes.append(shape[0].value)
+    num_params = sum(partition_sizes)  # Total number of parameters.
+
+    # Assert the size of each partition.
+    for p in range(num_partitions):
+      expected_size = (num_params - p - 1) // num_partitions + 1
+      if partition_sizes[p] != expected_size:
+        raise ValueError("Tensor %d in params has size %d, expected %d." %
+                         (p, partition_sizes[p], expected_size))
+
+    # Flatten the values
+    values_shape = array_ops.shape(values)
+    values = array_ops.reshape(values, [-1, 1])
+
+    # With two values v1 and v2 and 3 dimensions, we will cross
+    # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
+    tensors_to_cross = [array_ops.tile(array_ops.expand_dims(
+        math_ops.range(0, dimension), 0), array_ops.shape(values)), values]
+    ids = sparse_feature_cross_op.sparse_feature_cross(
+        tensors_to_cross, hashed_output=True, num_buckets=num_params)
+    ids = sparse_ops.sparse_tensor_to_dense(ids)
+
+    # No need to validate the indices since we have checked the params
+    # dimensions and we know the largest id.
+    result = embedding_ops.embedding_lookup(
+        params, ids, partition_strategy="div", validate_indices=False)
+
+    return array_ops.reshape(result, array_ops.concat(
+        0, [values_shape, [dimension]]))
+
+
+def hashed_embedding_lookup_sparse(params,
+                                   sparse_values,
+                                   dimension,
+                                   combiner="mean",
+                                   default_value=None,
+                                   name=None):
+  """Looks up embeddings of a sparse feature using parameter hashing.
+
+  See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing.
+
+  Args:
+    params: A `Tensor` or `list` of `Tensors`.
+      Each tensor must be of rank 1 with fully-defined shape.
+    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
+      Some rows may be empty.
+    dimension: Embedding dimension
+    combiner: A string specifying how to combine embedding results for each
+        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+        the default.
+    default_value: The value to use for an entry with no features.
+    name: An optional name for this op.
+
+  Returns:
+     Dense tensor with shape [N, dimension] with N the number of rows in
+       sparse_values.
+
+  Raises:
+    TypeError: If sparse_values is not a SparseTensor.
+    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
+  """
+
+  if not isinstance(params, list):
+    params = [params]
+  if not isinstance(sparse_values, ops.SparseTensor):
+    raise TypeError("sparse_values must be SparseTensor")
+
+  with ops.op_scope(params + [sparse_values], name,
+                    "hashed_sparse_embedding_lookup") as scope:
+    # Fill in the empty rows.
+    if default_value is None:
+      # Random default values to reduce the risk of collision.
+      if sparse_values.dtype == dtypes.string:
+        default_value = "6ZxWzWOHxZ"
+      else:
+        default_value = 1288896567
+    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
+        sparse_values, default_value)
+
+    segment_ids = sparse_values.indices[:, 0]
+    if segment_ids.dtype != dtypes.int32:
+      segment_ids = math_ops.cast(segment_ids, dtypes.int32)
+
+    values = sparse_values.values
+    values, idx = array_ops.unique(values)
+
+    embeddings = hashed_embedding_lookup(params, values, dimension)
+
+    if combiner == "sum":
+      embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
+                                               name=scope)
+    elif combiner == "mean":
+      embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
+                                                name=scope)
+    elif combiner == "sqrtn":
+      embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids,
+                                                  name=scope)
+    else:
+      raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
+
+    return embeddings
diff --git a/tensorflow/contrib/framework/python/ops/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
similarity index 63%
rename from tensorflow/contrib/framework/python/ops/embedding_ops_test.py
rename to tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 78d4dabdc4e..8d2103207ec 100644
--- a/tensorflow/contrib/framework/python/ops/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -100,7 +100,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(
+          tf.contrib.layers.safe_embedding_lookup_sparse(
               embedding_weights, sparse_ids, sparse_weights).eval())
 
       self.assertAllClose(embedding_lookup_result, [
@@ -114,7 +114,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(
+          tf.contrib.layers.safe_embedding_lookup_sparse(
               embedding_weights,
               sparse_ids,
               sparse_weights,
@@ -132,9 +132,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, _ = self._ids_and_weights_2d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
-                                                            sparse_ids,
-                                                            None).eval())
+          tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
+                                                         sparse_ids,
+                                                         None).eval())
 
       self.assertAllClose(embedding_lookup_result, [
           (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@@ -148,9 +148,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, _ = self._ids_and_weights_2d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
-                                                            sparse_ids,
-                                                            None).eval())
+          tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
+                                                         sparse_ids,
+                                                         None).eval())
 
       embedding_weights = list(itertools.chain(*embedding_weights))
       self.assertAllClose(embedding_lookup_result, [
@@ -166,13 +166,13 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
 
       embedding_weights[1] = embedding_weights[1].astype(np.float64)
       self.assertRaises(ValueError,
-                        tf.contrib.framework.safe_embedding_lookup_sparse,
+                        tf.contrib.layers.safe_embedding_lookup_sparse,
                         embedding_weights, sparse_ids)
       embedding_weights = [
           tf.constant(w, dtype=tf.float64) for w in embedding_weights
       ]
       self.assertRaises(ValueError,
-                        tf.contrib.framework.safe_embedding_lookup_sparse,
+                        tf.contrib.layers.safe_embedding_lookup_sparse,
                         embedding_weights, sparse_ids, sparse_weights)
 
   def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
@@ -181,7 +181,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(
+          tf.contrib.layers.safe_embedding_lookup_sparse(
               embedding_weights, sparse_ids, sparse_weights).eval())
 
       self.assertAllClose(embedding_lookup_result, [
@@ -195,7 +195,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(
+          tf.contrib.layers.safe_embedding_lookup_sparse(
               embedding_weights,
               sparse_ids,
               sparse_weights,
@@ -214,9 +214,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, _ = self._ids_and_weights_3d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
-                                                            sparse_ids,
-                                                            None).eval())
+          tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
+                                                         sparse_ids,
+                                                         None).eval())
 
       self.assertAllClose(embedding_lookup_result, [
           [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
@@ -231,9 +231,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
       sparse_ids, _ = self._ids_and_weights_3d()
 
       embedding_lookup_result = (
-          tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights,
-                                                            sparse_ids,
-                                                            None).eval())
+          tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights,
+                                                         sparse_ids,
+                                                         None).eval())
 
       embedding_weights = list(itertools.chain(*embedding_weights))
       self.assertAllClose(embedding_lookup_result, [
@@ -251,15 +251,110 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase):
 
       embedding_weights[1] = embedding_weights[1].astype(np.float64)
       self.assertRaises(ValueError,
-                        tf.contrib.framework.safe_embedding_lookup_sparse,
+                        tf.contrib.layers.safe_embedding_lookup_sparse,
                         embedding_weights, sparse_ids)
       embedding_weights = [
           tf.constant(w, dtype=tf.float64) for w in embedding_weights
       ]
       self.assertRaises(ValueError,
-                        tf.contrib.framework.safe_embedding_lookup_sparse,
+                        tf.contrib.layers.safe_embedding_lookup_sparse,
                         embedding_weights, sparse_ids, sparse_weights)
 
 
+class HashedEmbeddingLookupTest(tf.test.TestCase):
+
+  def setUp(self):
+    tf.set_random_seed(1)
+
+  def _random_weights(self, size=50, num_shards=1):
+    assert size > 0
+    assert num_shards > 0
+    assert num_shards <= size
+
+    embedding_weights = tf.create_partitioned_variables(
+        shape=[size],
+        slicing=[num_shards],
+        initializer=tf.truncated_normal_initializer(mean=0.0,
+                                                    stddev=1.0,
+                                                    dtype=tf.float32))
+    for w in embedding_weights:
+      w.initializer.run()
+    return embedding_weights
+
+  def test_hashed_embedding_consistency(self):
+    with self.test_session():
+      embedding_weights = self._random_weights()
+      values = tf.constant(["foo", "foo"])
+
+      embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
+          embedding_weights, values, dimension=10).eval()
+
+      self.assertAllEqual(embedding_lookup_result.shape, [2, 10])
+      self.assertAllEqual(embedding_lookup_result[0],
+                          embedding_lookup_result[1])
+
+  def test_hashed_embedding_multiple_partition(self):
+    with self.test_session():
+      embedding_weights = self._random_weights(num_shards=7)
+      values = tf.constant([4, 4, 5])
+
+      embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
+          embedding_weights, values, dimension=5).eval()
+
+      self.assertAllEqual(embedding_lookup_result.shape, [3, 5])
+      self.assertAllEqual(embedding_lookup_result[0],
+                          embedding_lookup_result[1])
+      # Different embedding expected for different value.
+      embedding_diff = np.min((embedding_lookup_result[2] -
+                               embedding_lookup_result[0]) ** 2)
+      self.assertGreater(embedding_diff, 0)
+
+  def test_hashed_embedding_coverage(self):
+    with self.test_session():
+      size = 8
+      embedding_weights = self._random_weights(size=size, num_shards=3)
+      values = tf.constant(["foo"])
+
+      # Large embedding dimension to cover the full range of weights.
+      embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
+          embedding_weights, values, dimension=100).eval()
+
+      self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
+
+  def test_hashed_embedding_multi_dimension(self):
+    with self.test_session():
+      embedding_weights = self._random_weights()
+      values = tf.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]])
+
+      embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup(
+          embedding_weights, values, dimension=10).eval()
+
+      self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 10])
+      self.assertAllEqual(embedding_lookup_result[0][0],
+                          embedding_lookup_result[1][2])
+
+  def test_hashed_embedding_lookup_sparse(self):
+    with self.test_session():
+      embedding_weights = self._random_weights(num_shards=3)
+      sparse_tensor = tf.SparseTensor(values=["foo", "bar", "foo", "bar"],
+                                      indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
+                                      shape=[5, 2])
+
+      embedding_lookup_result = (
+          tf.contrib.layers.hashed_embedding_lookup_sparse(
+              embedding_weights, sparse_tensor, dimension=5, combiner="mean")
+          .eval())
+
+      self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
+      # Same non-zero embedding for the empty rows filled with a default value.
+      self.assertAllEqual(embedding_lookup_result[2],
+                          embedding_lookup_result[4])
+      embedding_norm = np.sum(embedding_lookup_result[2] ** 2)
+      self.assertGreater(embedding_norm, 0)
+
+      self.assertAllEqual(embedding_lookup_result[1],
+                          0.5 * (embedding_lookup_result[0] +
+                                 embedding_lookup_result[3]))
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index c05aa6c2917..ba19d773555 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -75,7 +75,7 @@ import abc
 import collections
 import math
 
-from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
+from tensorflow.contrib.layers.python.layers import embedding_ops
 from tensorflow.contrib.layers.python.ops import bucketization_op
 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
 from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
@@ -517,7 +517,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
 
     # This is effectively the same format as str(self), except with our special
     # treatment.
-    return "_EmbeddingColumn(%s)" % ", ".join(fields_values)
+    return "%s(%s)" % (type(self).__name__, ", ".join(fields_values))
 
   def insert_transformed_feature(self, columns_to_tensors):
     self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
@@ -576,6 +576,105 @@ def embedding_column(sparse_id_column,
   return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
 
 
+class _HashedEmbeddingColumn(collections.namedtuple(
+    "_HashedEmbeddingColumn", ["column_name", "size", "dimension", "combiner",
+                               "initializer"]), _EmbeddingColumn):
+  """See `hashed_embedding_column`."""
+
+  def __new__(cls,
+              column_name,
+              size,
+              dimension,
+              combiner="mean",
+              initializer=None):
+    if initializer is not None and not callable(initializer):
+      raise ValueError("initializer must be callable if specified.")
+    if initializer is None:
+      stddev = 0.1
+      # TODO(b/25671353): Better initial value?
+      initializer = init_ops.truncated_normal_initializer(mean=0.0,
+                                                          stddev=stddev)
+    return super(_HashedEmbeddingColumn, cls).__new__(cls, column_name, size,
+                                                      dimension, combiner,
+                                                      initializer)
+
+  @property
+  def name(self):
+    return self.column_name + "_embedding"
+
+  @property
+  def config(self):
+    return {self.column_name: parsing_ops.VarLenFeature(dtypes.string)}
+
+  def insert_transformed_feature(self, columns_to_tensors):
+    columns_to_tensors[self] = columns_to_tensors[self.column_name]
+
+  def to_dnn_input_layer(self,
+                         input_tensor,
+                         weight_collections=None,
+                         trainable=True):
+    # Same heuristic for the number of shards as _max_size_embedding_partitioner
+    max_shard_bytes = (64 << 20) - 1
+    shards = self.size * 4.0  / max_shard_bytes
+    shards = max(1, int(math.ceil(shards)))
+
+    embeddings = partitioned_variables.create_partitioned_variables(
+        shape=[self.size],
+        slicing=[shards],
+        initializer=self.initializer,
+        dtype=dtypes.float32,
+        collections=_add_variable_collection(weight_collections),
+        name=self.name + "_weights",
+        reuse=False,
+        trainable=trainable)
+
+    return embedding_ops.hashed_embedding_lookup_sparse(
+        embeddings, input_tensor, self.dimension, name=self.name + "_lookup")
+
+
+def hashed_embedding_column(column_name,
+                            size,
+                            dimension,
+                            combiner="mean",
+                            initializer=None):
+  """Creates an embedding column of a sparse feature using parameter hashing.
+
+  The i-th embedding component of a value v is found by retrieving an
+  embedding weight whose index is a fingerprint of the pair (v,i).
+
+  Args:
+    column_name: A string defining sparse column name.
+    size: An integer specifying the number of parameters in the embedding layer.
+    dimension: An integer specifying dimension of the embedding.
+    combiner: A string specifying how to reduce if there are multiple entries
+      in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each
+      of this can be thought as example level normalizations on the column:
+        * "sum": do not normalize features in the column
+        * "mean": do l1 normalization on features in the column
+        * "sqrtn": do l2 normalization on features in the column
+      For more information: `tf.embedding_lookup_sparse`.
+    initializer: A variable initializer function to be used in embedding
+      variable initialization. If not specified, defaults to
+      `tf.truncated_normal_initializer` with mean 0 and standard deviation 0.1.
+
+  Returns:
+    A _HashedEmbeddingColumn.
+
+  Raises:
+    ValueError: if dimension or size is not a positive integer; or if combiner
+      is not supported.
+
+  """
+  if (dimension < 1) or (size < 1):
+    raise ValueError("Dimension and size must be greater than 0.")
+
+  if combiner not in ("mean", "sqrtn", "sum"):
+    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
+
+  return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
+                                initializer)
+
+
 class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
     "_RealValuedColumn",
     ["column_name", "dimension", "default_value", "dtype"])):
@@ -1237,7 +1336,7 @@ def _create_embedding_lookup(input_tensor, vocab_size, dimension,
       reuse=False,
       trainable=trainable)
 
-  return contrib_embedding_ops.safe_embedding_lookup_sparse(
+  return embedding_ops.safe_embedding_lookup_sparse(
       embeddings,
       input_tensor,
       default_id=0,
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 60516ddae59..1d0f45357ed 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -282,6 +282,27 @@ class InputLayerTest(tf.test.TestCase):
       tf.initialize_all_variables().run()
       self.assertAllEqual(output.eval().shape, [2, 10])
 
+  def testHashedEmbeddingColumn(self):
+    wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo", "omar"],
+                                  indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
+                                  shape=[3, 2])
+
+    features = {"wire": wire_tensor}
+    # Big enough hash space so that hopefully there is no collision
+    embedded_sparse = tf.contrib.layers.hashed_embedding_column("wire", 1000, 3)
+    output = tf.contrib.layers.input_from_feature_columns(
+        features, [embedded_sparse], weight_collections=["my_collection"])
+    weights = tf.get_collection("my_collection")
+    grad = tf.gradients(output, weights)
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      gradient_values = []
+      # Collect the gradient from the different partitions (one in this test)
+      for p in range(len(grad)):
+        gradient_values.extend(grad[p].values.eval())
+      gradient_values.sort()
+      self.assertAllEqual(gradient_values, [0.5]*6 + [2]*3)
+
   def testEmbeddingColumnWithInitializer(self):
     hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
     wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],