From c9d03a568a221e96c47ee7d5be703984d61b95a4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 12 Jul 2017 12:44:53 -0700
Subject: [PATCH] Add tf.contrib.nn.rank_sampled_softmax_loss, a variant of
 tf.nn.sampled_softmax_loss that has been shown to improve rank loss.

Paper: https://arxiv.org/abs/1707.03073

PiperOrigin-RevId: 161702455
---
 tensorflow/contrib/nn/BUILD                   |  21 ++
 .../contrib/nn/python/ops/sampling_ops.py     | 243 +++++++++++++
 .../nn/python/ops/sampling_ops_test.py        | 322 ++++++++++++++++++
 3 files changed, 586 insertions(+)
 create mode 100644 tensorflow/contrib/nn/python/ops/sampling_ops.py
 create mode 100644 tensorflow/contrib/nn/python/ops/sampling_ops_test.py

diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD
index dbac049d833..af33496e5d7 100644
--- a/tensorflow/contrib/nn/BUILD
+++ b/tensorflow/contrib/nn/BUILD
@@ -7,6 +7,8 @@ exports_files(["LICENSE"])
 
 package(default_visibility = ["//visibility:public"])
 
+load("//tensorflow:tensorflow.bzl", "py_test")
+
 py_library(
     name = "nn_py",
     srcs = [
@@ -14,15 +16,34 @@ py_library(
         "python/__init__.py",
         "python/ops/__init__.py",
         "python/ops/cross_entropy.py",
+        "python/ops/sampling_ops.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:nn",
         "//tensorflow/python:util",
     ],
 )
 
+py_test(
+    name = "sampling_ops_test",
+    size = "small",
+    srcs = ["python/ops/sampling_ops_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":nn_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:nn",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops.py b/tensorflow/contrib/nn/python/ops/sampling_ops.py
new file mode 100644
index 00000000000..7a9eed511bd
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops related to candidate sampling."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+
+def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
+                   resampling_temperature, partition_strategy):
+  """A helper function for rank_sampled_softmax_loss.
+
+  This computes, for each i in `sampled_values`,
+
+      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))
+
+  where w_i, b_i are the weight and bias of the i-th class, respectively,
+  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
+  computation to
+
+      log(sum_j exp(w_i * (x_j / resampling_temperature)))
+          + b_i / resampling_temperature.
+
+  This translates to the following batched computation using TensorFlow ops:
+
+      reduce_logsumexp(matmul(embeddings,
+                       transpose(inputs / resampling_temperature)))
+          + biases / resampling_temperature
+
+  The computation of the first term is colocated with the embeddings using
+  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The
+  second term, which is not the bottleneck, is computed at the worker.
+
+  Args:
+    weights: From `rank_sampled_softmax_loss`.
+    biases: From `rank_sampled_softmax_loss`.
+    inputs: From `rank_sampled_softmax_loss`.
+    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
+      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+    num_resampled: An `int`. This many values are selected from
+      `sampled_values` using the adaptive resampling algorithm. The caller
+      must ensure that `num_resampled` is less than the size of
+      `sampled_values`.
+    resampling_temperature: A scalar `Tensor` with the temperature parameter
+      for the adaptive resampling algorithm.
+    partition_strategy: From `rank_sampled_softmax_loss`.
+
+  Returns:
+    A tuple of (`resampled_candidates`, `true_expected_count`,
+      `resampled_expected_count`), similar to `sampled_values` but sampled
+      down to `num_resampled` values.
+  """
+  # This code supports passing a Tensor for num_resampled, but since it is only
+  # called with an int, that's what we specify in the arg list. If this
+  # function is ever externalized, we should change the doc to support Tensor.
+
+  sampled, true_expected_count, sampled_expected_count = sampled_values
+
+  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
+  true_expected_count = array_ops.stop_gradient(true_expected_count)
+  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)
+
+  reweighted_inputs = inputs / resampling_temperature
+
+  def logsumexp_logit(embeddings):
+    return math_ops.reduce_logsumexp(
+        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
+        axis=1,
+        keep_dims=False)
+
+  # Calling this protected form of embedding_lookup allows co-locating
+  # the logsumexp computation with the partitioned weights, which yields
+  # a large speedup in practice.
+  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
+      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
+  sampled_b = array_ops.reshape(
+      embedding_ops.embedding_lookup(biases, sampled, partition_strategy),
+      [-1])
+  sampled_logits += sampled_b / resampling_temperature
+
+  _, resampled_indices = nn.top_k(
+      sampled_logits, k=num_resampled, sorted=False)
+  resampled = array_ops.gather(sampled, indices=resampled_indices)
+  resampled_expected_count = array_ops.gather(
+      sampled_expected_count, indices=resampled_indices)
+
+  return resampled, true_expected_count, resampled_expected_count
+
+
+# TODO(ccolby): Before checkin, add a reference to the TAPAS paper once it is
+# on arxiv.org.
+def rank_sampled_softmax_loss(weights,
+                              biases,
+                              labels,
+                              inputs,
+                              num_sampled,
+                              num_resampled,
+                              num_classes,
+                              num_true,
+                              sampled_values,
+                              resampling_temperature,
+                              remove_accidental_hits,
+                              partition_strategy,
+                              name=None):
+  """Computes softmax loss using rank-based adaptive resampling.
+
+  This has been shown to improve rank loss after training compared to
+  @{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some
+  experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
+  Sampling for Softmax](https://arxiv.org/abs/1707.03073).
+
+  Sampling follows two phases:
+  * In the first phase, `num_sampled` classes are selected using
+    @{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`.
+    The logits are calculated on those sampled classes. This phase is
+    similar to @{tf.nn.sampled_softmax_loss}.
+  * In the second phase, the `num_resampled` classes with highest predicted
+    probability are kept. Probabilities are
+    `LogSumExp(logits / resampling_temperature)`, where the sum is over
+    `inputs`.
+
+  The `resampling_temperature` parameter controls the "adaptiveness" of the
+  resampling. At lower temperatures, resampling is more adaptive because it
+  picks more candidates close to the predicted classes. A common strategy is
+  to decrease the temperature as training proceeds.
+
+  See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and
+  for typical default values for some of the parameters.
+
+  This operation is for training only. It is generally an underestimate of
+  the full softmax loss.
+
+  A common use case is to use this method for training, and calculate the
+  full softmax loss for evaluation or inference. In this case, you must set
+  `partition_strategy="div"` for the two losses to be consistent, as in the
+  following example:
+
+  ```python
+  if mode == "train":
+    loss = rank_sampled_softmax_loss(
+        weights=weights,
+        biases=biases,
+        labels=labels,
+        inputs=inputs,
+        ...,
+        partition_strategy="div")
+  elif mode == "eval":
+    logits = tf.matmul(inputs, tf.transpose(weights))
+    logits = tf.nn.bias_add(logits, biases)
+    labels_one_hot = tf.one_hot(labels, n_classes)
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=labels_one_hot,
+        logits=logits)
+  ```
+
+  Args:
+    weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
+      or a list of `Tensor` objects whose concatenation along dimension 0
+      has shape `[num_classes, dim]`. The (possibly-sharded) class embeddings.
+    biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
+      The (possibly-sharded) class biases.
+    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`.
+      The target classes. Note that this format differs from the `labels`
+      argument of `nn.softmax_cross_entropy_with_logits`.
+    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+      activations of the input network.
+    num_sampled: An `int`. The number of classes to randomly sample per batch.
+    num_resampled: An `int`. The number of classes to select from the
+      `num_sampled` classes using the adaptive resampling algorithm. Must be
+      less than `num_sampled`.
+    num_classes: An `int`. The number of possible classes.
+    num_true: An `int`. The number of target classes per training example.
+    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
+      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+      If `None`, defaults to `nn.learned_unigram_candidate_sampler`.
+    resampling_temperature: A scalar `Tensor` with the temperature parameter
+      for the adaptive resampling algorithm.
+    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+      where a sampled class equals one of the target classes.
+    partition_strategy: A string specifying the partitioning strategy,
+      relevant if `len(weights) > 1`. Currently `"div"` and `"mod"` are
+      supported. See @{tf.nn.embedding_lookup} for more details.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `batch_size` 1-D tensor of per-example sampled softmax losses.
+
+  Raises:
+    ValueError: If `num_sampled > num_classes`, if
+      `num_sampled <= num_resampled`, or if `partition_strategy` is neither
+      `"div"` nor `"mod"`.
+  """
+  if num_sampled > num_classes:
+    raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
+                     format(num_sampled, num_classes))
+  if num_sampled <= num_resampled:
+    raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
+                     format(num_resampled, num_sampled))
+  if partition_strategy not in ("div", "mod"):
+    raise ValueError(
+        "unsupported partition_strategy ({})".format(partition_strategy))
+  with ops.name_scope(name, "rank_sampled_softmax_loss", [
+      weights, biases, labels, inputs, sampled_values, resampling_temperature
+  ]) as name:
+    if not sampled_values:
+      sampled_values = nn.learned_unigram_candidate_sampler(
+          true_classes=labels,
+          num_true=num_true,
+          num_sampled=num_sampled,
+          unique=True,
+          range_max=num_classes)
+    # From sampled_values, select the top num_resampled values using the
+    # adaptive rank resampling strategy.
+    resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
                                      num_resampled, resampling_temperature,
                                      partition_strategy)
+    return nn.sampled_softmax_loss(
+        weights=weights,
+        biases=biases,
+        labels=labels,
+        inputs=inputs,
+        num_sampled=num_resampled,
+        num_classes=num_classes,
+        num_true=num_true,
+        sampled_values=resampled_values,
+        remove_accidental_hits=remove_accidental_hits,
+        partition_strategy=partition_strategy,
+        name=name)
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops_test.py b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
new file mode 100644
index 00000000000..1d4fe1321b8
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sampling_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.nn.python.ops import sampling_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+class RankSampledSoftmaxLossTest(test.TestCase):
+
+  def setUp(self):
+    self._sampled = [3, 4, 5, 6, 7]
+    self._num_sampled = len(self._sampled)
+    # Because values of all matrices increase with indices, logits increase with
+    # class id. So, for the above sampled classes, adaptive sampling will select
+    # these resampled classes.
+    self._resampled = [5, 6, 7]
+    self._num_resampled = len(self._resampled)
+    self._num_classes = 10
+    self._num_true = 2
+    self._sampled_values = (self._sampled, [[0.5], [0.5]],
+                            [0.5, 0.5, 0.5, 0.5, 0.5])
+    self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
+    self._remove_accidental_hits = False
+    self._embed_dim = 5
+    self._batch_size = 2
+
+  def _weights(self):
+    return constant_op.constant([
+        [0.0, 0.1, 0.2, 0.3, 0.4],
+        [1.0, 1.1, 1.2, 1.3, 1.4],
+        [2.0, 2.1, 2.2, 2.3, 2.4],
+        [3.0, 3.1, 3.2, 3.3, 3.4],
+        [4.0, 4.1, 4.2, 4.3, 4.4],
+        [5.0, 5.1, 5.2, 5.3, 5.4],
+        [6.0, 6.1, 6.2, 6.3, 6.4],
+        [7.0, 7.1, 7.2, 7.3, 7.4],
+        [8.0, 8.1, 8.2, 8.3, 8.4],
+        [9.0, 9.1, 9.2, 9.3, 9.4],
+    ])
+
+  def _div_sharded_weights(self):
+    return [
+        constant_op.constant([
+            [0.0, 0.1, 0.2, 0.3, 0.4],
+            [1.0, 1.1, 1.2, 1.3, 1.4],
+        ]),
+        constant_op.constant([
+            [2.0, 2.1, 2.2, 2.3, 2.4],
+            [3.0, 3.1, 3.2, 3.3, 3.4],
+        ]),
+        constant_op.constant([
+            [4.0, 4.1, 4.2, 4.3, 4.4],
+            [5.0, 5.1, 5.2, 5.3, 5.4],
+        ]),
+        constant_op.constant([
+            [6.0, 6.1, 6.2, 6.3, 6.4],
+            [7.0, 7.1, 7.2, 7.3, 7.4],
+        ]),
+        constant_op.constant([
+            [8.0, 8.1, 8.2, 8.3, 8.4],
+            [9.0, 9.1, 9.2, 9.3, 9.4],
+        ]),
+    ]
+
+  def _mod_sharded_weights(self):
+    return [
+        constant_op.constant([
+            [0.0, 0.1, 0.2, 0.3, 0.4],
+            [5.0, 5.1, 5.2, 5.3, 5.4],
+        ]),
+        constant_op.constant([
+            [1.0, 1.1, 1.2, 1.3, 1.4],
+            [6.0, 6.1, 6.2, 6.3, 6.4],
+        ]),
+        constant_op.constant([
+            [2.0, 2.1, 2.2, 2.3, 2.4],
+            [7.0, 7.1, 7.2, 7.3, 7.4],
+        ]),
+        constant_op.constant([
+            [3.0, 3.1, 3.2, 3.3, 3.4],
+            [8.0, 8.1, 8.2, 8.3, 8.4],
+        ]),
+        constant_op.constant([
+            [4.0, 4.1, 4.2, 4.3, 4.4],
+            [9.0, 9.1, 9.2, 9.3, 9.4],
+        ]),
+    ]
+
+  def _biases(self):
+    return constant_op.constant(
+        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
+
+  def _div_sharded_biases(self):
+    return [
+        constant_op.constant([0.0, 0.1]),
+        constant_op.constant([0.2, 0.3]),
+        constant_op.constant([0.4, 0.5]),
+        constant_op.constant([0.6, 0.7]),
+        constant_op.constant([0.8, 0.9]),
+    ]
+
+  def _mod_sharded_biases(self):
+    return [
+        constant_op.constant([0.0, 0.5]),
+        constant_op.constant([0.1, 0.6]),
+        constant_op.constant([0.2, 0.7]),
+        constant_op.constant([0.3, 0.8]),
+        constant_op.constant([0.4, 0.9]),
+    ]
+
+  def _labels(self):
+    return constant_op.constant(
+        [[0, 1], [1, 2]],
+        shape=(self._batch_size, self._num_true),
+        name='labels',
+        dtype=dtypes.int64)
+
+  def _inputs(self):
+    return constant_op.constant(
+        [
+            [0., 1., 2., 3., 4.],
+            [10., 11., 12., 13., 14.],
+        ],
+        shape=(self._batch_size, self._embed_dim),
+        name='inputs')
+
+  def testInvalidNumSampled0(self):
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'num_resampled \(3\) must be less than num_sampled \(3\)'):
+        sampling_ops.rank_sampled_softmax_loss(
+            weights=self._weights(),
+            biases=self._biases(),
+            labels=self._labels(),
+            inputs=self._inputs(),
+            num_sampled=3,
+            num_resampled=3,
+            num_classes=self._num_classes,
+            num_true=self._num_true,
+            sampled_values=None,
+            resampling_temperature=1.,
+            remove_accidental_hits=True,
+            partition_strategy='div')
+
+  def testInvalidNumSampled1(self):
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'num_resampled \(3\) must be less than num_sampled \(2\)'):
+        sampling_ops.rank_sampled_softmax_loss(
+            weights=self._weights(),
+            biases=self._biases(),
+            labels=self._labels(),
+            inputs=self._inputs(),
+            num_sampled=2,
+            num_resampled=3,
+            num_classes=self._num_classes,
+            num_true=self._num_true,
+            sampled_values=None,
+            resampling_temperature=1.,
+            remove_accidental_hits=True,
+            partition_strategy='div')
+
+  def testMissingPartitionStrategy(self):
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(ValueError,
+                                   r'unsupported partition_strategy \(None\)'):
+        sampling_ops.rank_sampled_softmax_loss(
+            weights=self._weights(),
+            biases=self._biases(),
+            labels=self._labels(),
+            inputs=self._inputs(),
+            num_sampled=2,
+            num_resampled=1,
+            num_classes=self._num_classes,
+            num_true=self._num_true,
+            sampled_values=None,
+            resampling_temperature=1.,
+            remove_accidental_hits=True,
+            partition_strategy=None)
+
+  def _testCompareWithNN(self, weights, biases, partition_strategy):
+    with ops.Graph().as_default():
+      loss = sampling_ops.rank_sampled_softmax_loss(
+          weights=weights(),
+          biases=biases(),
+          labels=self._labels(),
+          inputs=self._inputs(),
+          num_sampled=self._num_sampled,
+          num_resampled=self._num_resampled,
+          num_classes=self._num_classes,
+          num_true=self._num_true,
+          sampled_values=self._sampled_values,
+          resampling_temperature=1.,
+          remove_accidental_hits=self._remove_accidental_hits,
+          partition_strategy=partition_strategy)
+      loss_nn = nn.sampled_softmax_loss(
+          weights=weights(),
+          biases=biases(),
+          labels=self._labels(),
+          inputs=self._inputs(),
+          num_sampled=self._num_resampled,
+          num_classes=self._num_classes,
+          num_true=self._num_true,
+          sampled_values=self._resampled_values,
+          remove_accidental_hits=self._remove_accidental_hits,
+          partition_strategy=partition_strategy)
+      with self.test_session() as sess:
+        loss_val = sess.run(loss)
+        loss_nn_val = sess.run(loss_nn)
+
+    self.assertAllClose(loss_val, loss_nn_val)
+
+  def testCompareWithNNUnsharded(self):
+    self._testCompareWithNN(self._weights, self._biases, 'div')
+
+  def testCompareWithNNShardWeightsDiv(self):
+    self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')
+
+  def testCompareWithNNShardWeightsAndBiasesDiv(self):
+    self._testCompareWithNN(self._div_sharded_weights,
+                            self._div_sharded_biases, 'div')
+
+  def testCompareWithNNShardWeightsMod(self):
+    self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')
+
+  def testCompareWithNNShardWeightsAndBiasesMod(self):
+    self._testCompareWithNN(self._mod_sharded_weights,
+                            self._mod_sharded_biases, 'mod')
+
+  def _testCompareWithNNTemperature(self, temperature, resampled):
+    weights = [[1., 2.], [3., 4.]]  # two sampled classes
+    inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
+    # Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
+    # Let x0, x1 = inputs
+    # logits:
+    #   w0.x0 = 1
+    #   w0.x1 = 10
+    #   w1.x0 = 8
+    #   w1.x1 = 9
+    # Resampling 1 class with temperature = t will pick the larger of:
+    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
+    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
+    num_sampled = 2
+    num_resampled = 1
+    num_classes = 2
+    num_true = 1
+    sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
+    resampled_values = [resampled], [[1.], [1.]], [1.]
+    remove_accidental_hits = False
+    with ops.Graph().as_default():
+      weights = constant_op.constant(weights)
+      biases = constant_op.constant([0., 0.])
+      labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
+      inputs = constant_op.constant(inputs)
+      loss = sampling_ops.rank_sampled_softmax_loss(
+          weights=weights,
+          biases=biases,
+          labels=labels,
+          inputs=inputs,
+          num_sampled=num_sampled,
+          num_resampled=num_resampled,
+          num_classes=num_classes,
+          num_true=num_true,
+          sampled_values=sampled_values,
+          resampling_temperature=constant_op.constant(temperature),
+          remove_accidental_hits=remove_accidental_hits,
+          partition_strategy='div')
+      loss_nn = nn.sampled_softmax_loss(
+          weights=weights,
+          biases=biases,
+          labels=labels,
+          inputs=inputs,
+          num_sampled=num_resampled,
+          num_classes=num_classes,
+          num_true=num_true,
+          sampled_values=resampled_values,
+          remove_accidental_hits=remove_accidental_hits,
+          partition_strategy='div')
+      with self.test_session() as sess:
+        loss_val = sess.run(loss)
+        loss_nn_val = sess.run(loss_nn)
+
+    self.assertAllClose(loss_val, loss_nn_val)
+
+  def testCompareWithNNTemperatureLo1(self):
+    self._testCompareWithNNTemperature(1., 0)
+
+  def testCompareWithNNTemperatureLo2(self):
+    self._testCompareWithNNTemperature(2.12, 0)
+
+  def testCompareWithNNTemperatureHi1(self):
+    self._testCompareWithNNTemperature(2.13, 1)
+
+  def testCompareWithNNTemperatureHi2(self):
+    self._testCompareWithNNTemperature(3., 1)
+
+
+if __name__ == '__main__':
+  test.main()
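
To make the second-pass selection concrete, here is a minimal NumPy sketch (not part of the patch; all names are local to this example) of the resampling score log(sum_j exp((w_i . x_j + b_i) / t)) that `_rank_resample` computes, run on the same weights and inputs as `_testCompareWithNNTemperature` above:

```python
import numpy as np


def resample_scores(weights, biases, inputs, temperature):
  # For each candidate class i, compute
  #     log(sum_j exp((w_i . x_j + b_i) / t))
  # via the rearrangement from the _rank_resample docstring:
  #     logsumexp_j(w_i . (x_j / t)) + b_i / t
  logits = weights @ (inputs / temperature).T   # [num_sampled, batch]
  m = logits.max(axis=1, keepdims=True)         # stabilized logsumexp
  lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
  return lse + biases / temperature


weights = np.array([[1., 2.], [3., 4.]])        # the temperature test's values
biases = np.array([0., 0.])
inputs = np.array([[6., -2.5], [-11., 10.5]])

for t in (1.0, 3.0):
  scores = resample_scores(weights, biases, inputs, t)
  kept = int(np.argmax(scores))                 # top_k with num_resampled = 1
  print('t=%.1f scores=%s keeps class %d' % (t, scores, kept))
# t=1.0 keeps class 0 and t=3.0 keeps class 1, matching
# testCompareWithNNTemperatureLo1 and testCompareWithNNTemperatureHi2.
```

Lower temperature sharpens the exp() terms, so the single largest logit (w0.x1 = 10) dominates; higher temperature flattens them, favoring the class whose logits are uniformly high (w1).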
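And an end-to-end usage sketch of the new op, assuming a TF 1.x graph-mode runtime contemporary with this change; the import path matches the one used in the test file, and the tensor values are arbitrary toy data:

```python
import numpy as np
import tensorflow as tf

from tensorflow.contrib.nn.python.ops import sampling_ops

# Toy problem: 10 classes, 5-dim embeddings, batch of 2 examples with
# 2 true classes each (loosely mirroring the unit test's setup).
weights = tf.constant(np.arange(50, dtype=np.float32).reshape(10, 5))
biases = tf.constant(np.linspace(0.0, 0.9, 10).astype(np.float32))
labels = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)
inputs = tf.constant([[0., 1., 2., 3., 4.],
                      [10., 11., 12., 13., 14.]])

loss = sampling_ops.rank_sampled_softmax_loss(
    weights=weights,
    biases=biases,
    labels=labels,
    inputs=inputs,
    num_sampled=5,          # first-pass candidates
    num_resampled=3,        # kept after adaptive resampling; must be < 5
    num_classes=10,
    num_true=2,
    sampled_values=None,    # falls back to learned_unigram_candidate_sampler
    resampling_temperature=tf.constant(1.0),
    remove_accidental_hits=True,
    partition_strategy="div")

with tf.Session() as sess:
  print(sess.run(loss))     # one sampled softmax loss per example, shape [2]
```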