diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index b13de8e53f5..cb4749c88e7 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -77,17 +77,6 @@ py_test( ], ) -cuda_py_test( - name = "embedding_ops_test", - size = "small", - srcs = ["python/ops/embedding_ops_test.py"], - additional_deps = [ - ":framework_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - cuda_py_test( name = "sampling_ops_test", size = "small", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index f4827c06465..bd05720b264 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -213,6 +213,19 @@ py_test( ], ) +py_test( + name = "embedding_ops_test", + size = "small", + srcs = ["python/layers/embedding_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/python/layers/__init__.py b/tensorflow/contrib/layers/python/layers/__init__.py index 239aee7a3d2..f1f3b52a50c 100644 --- a/tensorflow/contrib/layers/python/layers/__init__.py +++ b/tensorflow/contrib/layers/python/layers/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.layers.python.layers.embedding_ops import * from tensorflow.contrib.layers.python.layers.feature_column import * from tensorflow.contrib.layers.python.layers.feature_column_ops import * from tensorflow.contrib.layers.python.layers.initializers import * diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py new file mode 100644 index 00000000000..4904c16a9cd --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -0,0 +1,191 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Embedding functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops +from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op +from tensorflow.python.framework import dtypes + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops + +__all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup", + "hashed_embedding_lookup_sparse"] + + +# TODO(chapelle): move the safe_embedding_lookup_sparse code here (b/29826543) +safe_embedding_lookup_sparse = contrib_embedding_ops.safe_embedding_lookup_sparse # pylint: disable=line-too-long + + +def hashed_embedding_lookup(params, values, dimension, name=None): + """Looks up embeddings using parameter hashing for each value in `values`. + + The i-th embedding component of a value v in `values` is found by retrieving + the weight whose index is a fingerprint of the pair (v,i). + The concept is explored as "feature hashing" for model compression in this + paper: http://arxiv.org/pdf/1504.04788.pdf + + Feature hashing has the pleasant effect of allowing us to compute an embedding + without needing a pre-determined vocabulary, relieving some amount of process + complexity. It also allows for us to maintain embeddings for possibly + trillions of features with a fixed amount of memory. + + Note that this is superior to out-of-vocabulary shared "hash buckets" in that + the embedding is extremely likely to be unique for each token as opposed to + being shared across probably-colliding tokens. The price is that we must + compute a hash once for each scalar in the token's embedding as opposed to + once per token. + + If `params` is a list, it represents a partition of the embedding parameters. + Each tensor in the list should have the same length, except for the first ones + which may have an additional element. For instance 10 parameters can be + partitioned in 4 tensors with length `[3, 3, 2, 2]`. + + Args: + params: A `Tensor` or `list` of `Tensors`. + Each tensor must be of rank 1 with fully-defined shape. + values: `Tensor` of values to be embedded. + dimension: Embedding dimension + name: An optional name for this op. + + Returns: + A tensor with shape [d0, ..., dn, dimension] + with shape(values) = [d0, ..., dn] + + Raises: + ValueError: if dimension is not positive or the partition size is invalid. + """ + if not isinstance(params, list): + params = [params] + + with ops.op_scope(params + [dimension, values], name, + "hashed_embedding_lookup"): + if dimension <= 0: + raise ValueError("Dimension should be >0 not %d" % dimension) + + num_partitions = len(params) + partition_sizes = [] + for p in range(num_partitions): + shape = params[p].get_shape() + shape.assert_has_rank(1) + shape.assert_is_fully_defined() + partition_sizes.append(shape[0].value) + num_params = sum(partition_sizes) # Total number of parameters. + + # Assert the size of each partition. + for p in range(num_partitions): + expected_size = (num_params - p - 1) // num_partitions + 1 + if partition_sizes[p] != expected_size: + raise ValueError("Tensor %d in params has size %d, expected %d." % + (p, partition_sizes[p], expected_size)) + + # Flatten the values + values_shape = array_ops.shape(values) + values = array_ops.reshape(values, [-1, 1]) + + # With two values v1 and v2 and 3 dimensions, we will cross + # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]]. + tensors_to_cross = [array_ops.tile(array_ops.expand_dims( + math_ops.range(0, dimension), 0), array_ops.shape(values)), values] + ids = sparse_feature_cross_op.sparse_feature_cross( + tensors_to_cross, hashed_output=True, num_buckets=num_params) + ids = sparse_ops.sparse_tensor_to_dense(ids) + + # No need to validate the indices since we have checked the params + # dimensions and we know the largest id. + result = embedding_ops.embedding_lookup( + params, ids, partition_strategy="div", validate_indices=False) + + return array_ops.reshape(result, array_ops.concat( + 0, [values_shape, [dimension]])) + + +def hashed_embedding_lookup_sparse(params, + sparse_values, + dimension, + combiner="mean", + default_value=None, + name=None): + """Looks up embeddings of a sparse feature using parameter hashing. + + See `tf.contrib.layers.hashed_embedding_lookup` for embedding with hashing. + + Args: + params: A `Tensor` or `list` of `Tensors`. + Each tensor must be of rank 1 with fully-defined shape. + sparse_values: A 2-D `SparseTensor` containing the values to be embedded. + Some rows may be empty. + dimension: Embedding dimension + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" + the default. + default_value: The value to use for an entry with no features. + name: An optional name for this op. + + Returns: + Dense tensor with shape [N, dimension] with N the number of rows in + sparse_values. + + Raises: + TypeError: If sparse_values is not a SparseTensor. + ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}. + """ + + if not isinstance(params, list): + params = [params] + if not isinstance(sparse_values, ops.SparseTensor): + raise TypeError("sparse_values must be SparseTensor") + + with ops.op_scope(params + [sparse_values], name, + "hashed_sparse_embedding_lookup") as scope: + # Fill in the empty rows. + if default_value is None: + # Random default values to reduce the risk of collision. + if sparse_values.dtype == dtypes.string: + default_value = "6ZxWzWOHxZ" + else: + default_value = 1288896567 + sparse_values, _ = sparse_ops.sparse_fill_empty_rows( + sparse_values, default_value) + + segment_ids = sparse_values.indices[:, 0] + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) + + values = sparse_values.values + values, idx = array_ops.unique(values) + + embeddings = hashed_embedding_lookup(params, values, dimension) + + if combiner == "sum": + embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids, + name=scope) + elif combiner == "mean": + embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids, + name=scope) + elif combiner == "sqrtn": + embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids, + name=scope) + else: + raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.") + + return embeddings diff --git a/tensorflow/contrib/framework/python/ops/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py similarity index 63% rename from tensorflow/contrib/framework/python/ops/embedding_ops_test.py rename to tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 78d4dabdc4e..8d2103207ec 100644 --- a/tensorflow/contrib/framework/python/ops/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -100,7 +100,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, sparse_weights = self._ids_and_weights_2d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse( + tf.contrib.layers.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights).eval()) self.assertAllClose(embedding_lookup_result, [ @@ -114,7 +114,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, sparse_weights = self._ids_and_weights_2d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse( + tf.contrib.layers.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights, @@ -132,9 +132,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, _ = self._ids_and_weights_2d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - None).eval()) + tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + None).eval()) self.assertAllClose(embedding_lookup_result, [ (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, @@ -148,9 +148,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, _ = self._ids_and_weights_2d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - None).eval()) + tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + None).eval()) embedding_weights = list(itertools.chain(*embedding_weights)) self.assertAllClose(embedding_lookup_result, [ @@ -166,13 +166,13 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): embedding_weights[1] = embedding_weights[1].astype(np.float64) self.assertRaises(ValueError, - tf.contrib.framework.safe_embedding_lookup_sparse, + tf.contrib.layers.safe_embedding_lookup_sparse, embedding_weights, sparse_ids) embedding_weights = [ tf.constant(w, dtype=tf.float64) for w in embedding_weights ] self.assertRaises(ValueError, - tf.contrib.framework.safe_embedding_lookup_sparse, + tf.contrib.layers.safe_embedding_lookup_sparse, embedding_weights, sparse_ids, sparse_weights) def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): @@ -181,7 +181,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, sparse_weights = self._ids_and_weights_3d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse( + tf.contrib.layers.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights).eval()) self.assertAllClose(embedding_lookup_result, [ @@ -195,7 +195,7 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, sparse_weights = self._ids_and_weights_3d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse( + tf.contrib.layers.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights, @@ -214,9 +214,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, _ = self._ids_and_weights_3d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - None).eval()) + tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + None).eval()) self.assertAllClose(embedding_lookup_result, [ [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, @@ -231,9 +231,9 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): sparse_ids, _ = self._ids_and_weights_3d() embedding_lookup_result = ( - tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - None).eval()) + tf.contrib.layers.safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + None).eval()) embedding_weights = list(itertools.chain(*embedding_weights)) self.assertAllClose(embedding_lookup_result, [ @@ -251,15 +251,110 @@ class SafeEmbeddingLookupSparseTest(tf.test.TestCase): embedding_weights[1] = embedding_weights[1].astype(np.float64) self.assertRaises(ValueError, - tf.contrib.framework.safe_embedding_lookup_sparse, + tf.contrib.layers.safe_embedding_lookup_sparse, embedding_weights, sparse_ids) embedding_weights = [ tf.constant(w, dtype=tf.float64) for w in embedding_weights ] self.assertRaises(ValueError, - tf.contrib.framework.safe_embedding_lookup_sparse, + tf.contrib.layers.safe_embedding_lookup_sparse, embedding_weights, sparse_ids, sparse_weights) +class HashedEmbeddingLookupTest(tf.test.TestCase): + + def setUp(self): + tf.set_random_seed(1) + + def _random_weights(self, size=50, num_shards=1): + assert size > 0 + assert num_shards > 0 + assert num_shards <= size + + embedding_weights = tf.create_partitioned_variables( + shape=[size], + slicing=[num_shards], + initializer=tf.truncated_normal_initializer(mean=0.0, + stddev=1.0, + dtype=tf.float32)) + for w in embedding_weights: + w.initializer.run() + return embedding_weights + + def test_hashed_embedding_consistency(self): + with self.test_session(): + embedding_weights = self._random_weights() + values = tf.constant(["foo", "foo"]) + + embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup( + embedding_weights, values, dimension=10).eval() + + self.assertAllEqual(embedding_lookup_result.shape, [2, 10]) + self.assertAllEqual(embedding_lookup_result[0], + embedding_lookup_result[1]) + + def test_hashed_embedding_multiple_partition(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=7) + values = tf.constant([4, 4, 5]) + + embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup( + embedding_weights, values, dimension=5).eval() + + self.assertAllEqual(embedding_lookup_result.shape, [3, 5]) + self.assertAllEqual(embedding_lookup_result[0], + embedding_lookup_result[1]) + # Different embedding expected for different value. + embedding_diff = np.min((embedding_lookup_result[2] - + embedding_lookup_result[0]) ** 2) + self.assertGreater(embedding_diff, 0) + + def test_hashed_embedding_coverage(self): + with self.test_session(): + size = 8 + embedding_weights = self._random_weights(size=size, num_shards=3) + values = tf.constant(["foo"]) + + # Large embedding dimension to cover the full range of weights. + embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup( + embedding_weights, values, dimension=100).eval() + + self.assertEqual(len(np.unique(embedding_lookup_result[0])), size) + + def test_hashed_embedding_multi_dimension(self): + with self.test_session(): + embedding_weights = self._random_weights() + values = tf.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) + + embedding_lookup_result = tf.contrib.layers.hashed_embedding_lookup( + embedding_weights, values, dimension=10).eval() + + self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 10]) + self.assertAllEqual(embedding_lookup_result[0][0], + embedding_lookup_result[1][2]) + + def test_hashed_embedding_lookup_sparse(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_tensor = tf.SparseTensor(values=["foo", "bar", "foo", "bar"], + indices=[[0, 0], [1, 0], [1, 1], [3, 0]], + shape=[5, 2]) + + embedding_lookup_result = ( + tf.contrib.layers.hashed_embedding_lookup_sparse( + embedding_weights, sparse_tensor, dimension=5, combiner="mean") + .eval()) + + self.assertAllEqual(embedding_lookup_result.shape, [5, 5]) + # Same non-zero embedding for the empty rows filled with a default value. + self.assertAllEqual(embedding_lookup_result[2], + embedding_lookup_result[4]) + embedding_norm = np.sum(embedding_lookup_result[2] ** 2) + self.assertGreater(embedding_norm, 0) + + self.assertAllEqual(embedding_lookup_result[1], + 0.5 * (embedding_lookup_result[0] + + embedding_lookup_result[3])) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index c05aa6c2917..ba19d773555 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -75,7 +75,7 @@ import abc import collections import math -from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops +from tensorflow.contrib.layers.python.layers import embedding_ops from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops @@ -517,7 +517,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( # This is effectively the same format as str(self), except with our special # treatment. - return "_EmbeddingColumn(%s)" % ", ".join(fields_values) + return "%s(%s)" % (type(self).__name__, ", ".join(fields_values)) def insert_transformed_feature(self, columns_to_tensors): self.sparse_id_column.insert_transformed_feature(columns_to_tensors) @@ -576,6 +576,105 @@ def embedding_column(sparse_id_column, return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer) +class _HashedEmbeddingColumn(collections.namedtuple( + "_HashedEmbeddingColumn", ["column_name", "size", "dimension", "combiner", + "initializer"]), _EmbeddingColumn): + """See `hashed_embedding_column`.""" + + def __new__(cls, + column_name, + size, + dimension, + combiner="mean", + initializer=None): + if initializer is not None and not callable(initializer): + raise ValueError("initializer must be callable if specified.") + if initializer is None: + stddev = 0.1 + # TODO(b/25671353): Better initial value? + initializer = init_ops.truncated_normal_initializer(mean=0.0, + stddev=stddev) + return super(_HashedEmbeddingColumn, cls).__new__(cls, column_name, size, + dimension, combiner, + initializer) + + @property + def name(self): + return self.column_name + "_embedding" + + @property + def config(self): + return {self.column_name: parsing_ops.VarLenFeature(dtypes.string)} + + def insert_transformed_feature(self, columns_to_tensors): + columns_to_tensors[self] = columns_to_tensors[self.column_name] + + def to_dnn_input_layer(self, + input_tensor, + weight_collections=None, + trainable=True): + # Same heuristic for the number of shards as _max_size_embedding_partitioner + max_shard_bytes = (64 << 20) - 1 + shards = self.size * 4.0 / max_shard_bytes + shards = max(1, int(math.ceil(shards))) + + embeddings = partitioned_variables.create_partitioned_variables( + shape=[self.size], + slicing=[shards], + initializer=self.initializer, + dtype=dtypes.float32, + collections=_add_variable_collection(weight_collections), + name=self.name + "_weights", + reuse=False, + trainable=trainable) + + return embedding_ops.hashed_embedding_lookup_sparse( + embeddings, input_tensor, self.dimension, name=self.name + "_lookup") + + +def hashed_embedding_column(column_name, + size, + dimension, + combiner="mean", + initializer=None): + """Creates an embedding column of a sparse feature using parameter hashing. + + The i-th embedding component of a value v is found by retrieving an + embedding weight whose index is a fingerprint of the pair (v,i). + + Args: + column_name: A string defining sparse column name. + size: An integer specifying the number of parameters in the embedding layer. + dimension: An integer specifying dimension of the embedding. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently "mean", "sqrtn" and "sum" are supported. Each + of this can be thought as example level normalizations on the column: + * "sum": do not normalize features in the column + * "mean": do l1 normalization on features in the column + * "sqrtn": do l2 normalization on features in the column + For more information: `tf.embedding_lookup_sparse`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean 0 and standard deviation 0.1. + + Returns: + A _HashedEmbeddingColumn. + + Raises: + ValueError: if dimension or size is not a positive integer; or if combiner + is not supported. + + """ + if (dimension < 1) or (size < 1): + raise ValueError("Dimension and size must be greater than 0.") + + if combiner not in ("mean", "sqrtn", "sum"): + raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.") + + return _HashedEmbeddingColumn(column_name, size, dimension, combiner, + initializer) + + class _RealValuedColumn(_FeatureColumn, collections.namedtuple( "_RealValuedColumn", ["column_name", "dimension", "default_value", "dtype"])): @@ -1237,7 +1336,7 @@ def _create_embedding_lookup(input_tensor, vocab_size, dimension, reuse=False, trainable=trainable) - return contrib_embedding_ops.safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( embeddings, input_tensor, default_id=0, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 60516ddae59..1d0f45357ed 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -282,6 +282,27 @@ class InputLayerTest(tf.test.TestCase): tf.initialize_all_variables().run() self.assertAllEqual(output.eval().shape, [2, 10]) + def testHashedEmbeddingColumn(self): + wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo", "omar"], + indices=[[0, 0], [1, 0], [1, 1], [2, 0]], + shape=[3, 2]) + + features = {"wire": wire_tensor} + # Big enough hash space so that hopefully there is no collision + embedded_sparse = tf.contrib.layers.hashed_embedding_column("wire", 1000, 3) + output = tf.contrib.layers.input_from_feature_columns( + features, [embedded_sparse], weight_collections=["my_collection"]) + weights = tf.get_collection("my_collection") + grad = tf.gradients(output, weights) + with self.test_session(): + tf.initialize_all_variables().run() + gradient_values = [] + # Collect the gradient from the different partitions (one in this test) + for p in range(len(grad)): + gradient_values.extend(grad[p].values.eval()) + gradient_values.sort() + self.assertAllEqual(gradient_values, [0.5]*6 + [2]*3) + def testEmbeddingColumnWithInitializer(self): hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10) wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],