Add tf.contrib.nn.rank_sampled_softmax_loss, a variant of tf.nn.sampled_softmax_loss that has been shown to improve rank loss. Paper: https://arxiv.org/abs/1707.03073

PiperOrigin-RevId: 161702455
A. Unique TensorFlower 2017-07-12 12:44:53 -07:00 committed by TensorFlower Gardener
parent 9aa0dcbf28
commit c9d03a568a
3 changed files with 586 additions and 0 deletions
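
A minimal usage sketch of the new loss (hypothetical tensor names and hyperparameter values; the import path follows the test file added below):

```python
from tensorflow.contrib.nn.python.ops import sampling_ops

# weights: [num_classes, dim] class embeddings, biases: [num_classes],
# labels: [batch_size, num_true] int64 targets, inputs: [batch_size, dim] activations.
loss = sampling_ops.rank_sampled_softmax_loss(
    weights=weights,
    biases=biases,
    labels=labels,
    inputs=inputs,
    num_sampled=64,        # classes sampled in the first pass
    num_resampled=16,      # classes kept after adaptive resampling; must be < num_sampled
    num_classes=num_classes,
    num_true=1,
    sampled_values=None,   # defaults to learned_unigram_candidate_sampler
    resampling_temperature=1.0,
    remove_accidental_hits=True,
    partition_strategy="div")
```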
tensorflow/contrib/nn


@@ -7,6 +7,8 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "nn_py",
srcs = [
@@ -14,15 +16,34 @@ py_library(
"python/__init__.py",
"python/ops/__init__.py",
"python/ops/cross_entropy.py",
"python/ops/sampling_ops.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:util",
],
)
py_test(
name = "sampling_ops_test",
size = "small",
srcs = ["python/ops/sampling_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":nn_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:nn",
],
)
filegroup(
name = "all_files",
srcs = glob(

tensorflow/contrib/nn/python/ops/sampling_ops.py

@@ -0,0 +1,243 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops related to candidate sampling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
resampling_temperature, partition_strategy):
"""A helper function for rank_sampled_softmax_loss.
This computes, for each i in `sampled_values`,
log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))
where w_i, b_i are the weight and bias of the i-th class, respectively,
and j ranges over the rows of `inputs`. For efficiency, we rearrange the
computation to
log(sum_j exp(w_i * (x_j / resampling_temperature))) +
b_i / resampling_temperature.
This translates to the following batched computation using TensorFlow ops:
reduce_logsumexp(matmul(embeddings,
transpose(inputs / resampling_temperature))) +
biases / resampling_temperature
The computation of the first term is colocated with the embeddings using
`transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
term, which is not the bottleneck, is computed at the worker.
Args:
weights: From `rank_sampled_softmax_loss`.
biases: From `rank_sampled_softmax_loss`.
inputs: From `rank_sampled_softmax_loss`.
sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*_candidate_sampler` function.
num_resampled: An `int`. This many values are selected from
`sampled_values` using the adaptive resampling algorithm. The caller
must ensure that `num_resampled` is less than the size of
`sampled_values`.
resampling_temperature: A scalar `Tensor` with the temperature parameter
for the adaptive resampling algorithm.
partition_strategy: From `rank_sampled_softmax_loss`.
Returns:
A tuple of (`resampled_candidates`, `true_expected_count`,
`resampled_expected_count`), similar to `sampled_values` but sampled
down to `num_resampled` values.
"""
# This code supports passing a Tensor for num_resampled, but since it is only
# called with an int, that's what we specify in the arg list. If this
# function is ever externalized, we should change the doc to support Tensor.
sampled, true_expected_count, sampled_expected_count = sampled_values
sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
true_expected_count = array_ops.stop_gradient(true_expected_count)
sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)
reweighted_inputs = inputs / resampling_temperature
def logsumexp_logit(embeddings):
return math_ops.reduce_logsumexp(
math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
axis=1,
keep_dims=False)
# Calling this protected form of embedding_lookup allows co-locating
# the logsumexp computation with the partitioned weights, which yields
# a large speedup in practice.
sampled_logits = embedding_ops._embedding_lookup_and_transform( # pylint: disable=protected-access
weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
sampled_b = array_ops.reshape(
embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
sampled_logits += sampled_b / resampling_temperature
_, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
resampled = array_ops.gather(sampled, indices=resampled_indices)
resampled_expected_count = array_ops.gather(
sampled_expected_count, indices=resampled_indices)
return resampled, true_expected_count, resampled_expected_count
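
The rearrangement described in the docstring above can be sanity-checked outside of TensorFlow; a minimal NumPy sketch (illustrative only, with made-up shapes and values):

```python
import numpy as np

rng = np.random.RandomState(0)
W = rng.randn(5, 3)   # weights of 5 sampled classes, dim = 3
b = rng.randn(5)      # biases of the sampled classes
X = rng.randn(2, 3)   # batch of 2 inputs
t = 1.7               # resampling_temperature

# Direct form: log(sum_j exp((w_i . x_j + b_i) / t)) for each sampled class i.
direct = np.log(np.exp((X.dot(W.T) + b) / t).sum(axis=0))
# Rearranged form used above: logsumexp(matmul(W, (X / t)^T), axis=1) + b / t.
rearranged = np.log(np.exp(W.dot((X / t).T)).sum(axis=1)) + b / t
assert np.allclose(direct, rearranged)
```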
# TODO(ccolby): Before checkin, add reference to TAPAS paper when in arxiv.org.
def rank_sampled_softmax_loss(weights,
biases,
labels,
inputs,
num_sampled,
num_resampled,
num_classes,
num_true,
sampled_values,
resampling_temperature,
remove_accidental_hits,
partition_strategy,
name=None):
"""Computes softmax loss using rank-based adaptive resampling.
This has been shown to improve rank loss after training compared to
@{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some
experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
Sampling for Softmax](https://arxiv.org/abs/1707.03073).
Sampling follows two phases:
* In the first phase, `num_sampled` classes are selected using
@{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`.
The logits are calculated on those sampled classes. This phase is
similar to @{tf.nn.sampled_softmax_loss}.
* In the second phase, the `num_resampled` classes with highest predicted
probability are kept. Probabilities are
`LogSumExp(logits / resampling_temperature)`, where the sum is over
`inputs`.
The `resampling_temperature` parameter controls the "adaptiveness" of the
resampling. At lower temperatures, resampling is more adaptive because it
picks more candidates close to the predicted classes. A common strategy is
to decrease the temperature as training proceeds.
See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and
for typical default values for some of the parameters.
This operation is for training only. It is generally an underestimate of
the full softmax loss.
A common use case is to use this method for training, and calculate the full
softmax loss for evaluation or inference. In this case, you must set
`partition_strategy="div"` for the two losses to be consistent, as in the
following example:
```python
if mode == "train":
loss = rank_sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...,
partition_strategy="div")
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
loss = tf.nn.softmax_cross_entropy_with_logits(
labels=labels_one_hot,
logits=logits)
```
Args:
weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
or a list of `Tensor` objects whose concatenation along dimension 0
has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
The (possibly-sharded) class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_resampled: An `int`. The number of classes to select from the
`num_sampled` classes using the adaptive resampling algorithm. Must be
less than `num_sampled`.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*_candidate_sampler` function.
If None, default to `nn.learned_unigram_candidate_sampler`.
resampling_temperature: A scalar `Tensor` with the temperature parameter
for the adaptive resampling algorithm.
remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
where a sampled class equals one of the target classes.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
See @{tf.nn.embedding_lookup} for more details.
name: A name for the operation (optional).
Returns:
A `batch_size` 1-D tensor of per-example sampled softmax losses.
Raises:
ValueError: If `num_sampled <= num_resampled`.
"""
if num_sampled > num_classes:
raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
format(num_sampled, num_classes))
if num_sampled <= num_resampled:
raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
format(num_resampled, num_sampled))
if partition_strategy not in ("div", "mod"):
raise ValueError(
"unsupported partition_strategy ({})".format(partition_strategy))
with ops.name_scope(name, "rank_sampled_softmax_loss", [
weights, biases, labels, inputs, sampled_values, resampling_temperature
]) as name:
if not sampled_values:
sampled_values = nn.learned_unigram_candidate_sampler(
true_classes=labels,
num_true=num_true,
num_sampled=num_sampled,
unique=True,
range_max=num_classes)
# From sampled_values, select the top num_resampled values using the
# adaptive rank resampling strategy.
resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
num_resampled, resampling_temperature,
partition_strategy)
return nn.sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_resampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=resampled_values,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)

tensorflow/contrib/nn/python/ops/sampling_ops_test.py

@@ -0,0 +1,322 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sampling_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.nn.python.ops import sampling_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
class RankSampledSoftmaxLossTest(test.TestCase):
def setUp(self):
self._sampled = [3, 4, 5, 6, 7]
self._num_sampled = len(self._sampled)
# Because values of all matrices increase with indices, logits increase with
# class id. So, for the above sampled classes, adaptive sampling will select
# these resampled classes.
self._resampled = [5, 6, 7]
self._num_resampled = len(self._resampled)
self._num_classes = 10
self._num_true = 2
self._sampled_values = (self._sampled, [[0.5], [0.5]],
[0.5, 0.5, 0.5, 0.5, 0.5])
self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
self._remove_accidental_hits = False
self._embed_dim = 5
self._batch_size = 2
def _weights(self):
return constant_op.constant([
[0.0, 0.1, 0.2, 0.3, 0.4],
[1.0, 1.1, 1.2, 1.3, 1.4],
[2.0, 2.1, 2.2, 2.3, 2.4],
[3.0, 3.1, 3.2, 3.3, 3.4],
[4.0, 4.1, 4.2, 4.3, 4.4],
[5.0, 5.1, 5.2, 5.3, 5.4],
[6.0, 6.1, 6.2, 6.3, 6.4],
[7.0, 7.1, 7.2, 7.3, 7.4],
[8.0, 8.1, 8.2, 8.3, 8.4],
[9.0, 9.1, 9.2, 9.3, 9.4],
])
def _div_sharded_weights(self):
return [
constant_op.constant([
[0.0, 0.1, 0.2, 0.3, 0.4],
[1.0, 1.1, 1.2, 1.3, 1.4],
]),
constant_op.constant([
[2.0, 2.1, 2.2, 2.3, 2.4],
[3.0, 3.1, 3.2, 3.3, 3.4],
]),
constant_op.constant([
[4.0, 4.1, 4.2, 4.3, 4.4],
[5.0, 5.1, 5.2, 5.3, 5.4],
]),
constant_op.constant([
[6.0, 6.1, 6.2, 6.3, 6.4],
[7.0, 7.1, 7.2, 7.3, 7.4],
]),
constant_op.constant([
[8.0, 8.1, 8.2, 8.3, 8.4],
[9.0, 9.1, 9.2, 9.3, 9.4],
]),
]
def _mod_sharded_weights(self):
return [
constant_op.constant([
[0.0, 0.1, 0.2, 0.3, 0.4],
[5.0, 5.1, 5.2, 5.3, 5.4],
]),
constant_op.constant([
[1.0, 1.1, 1.2, 1.3, 1.4],
[6.0, 6.1, 6.2, 6.3, 6.4],
]),
constant_op.constant([
[2.0, 2.1, 2.2, 2.3, 2.4],
[7.0, 7.1, 7.2, 7.3, 7.4],
]),
constant_op.constant([
[3.0, 3.1, 3.2, 3.3, 3.4],
[8.0, 8.1, 8.2, 8.3, 8.4],
]),
constant_op.constant([
[4.0, 4.1, 4.2, 4.3, 4.4],
[9.0, 9.1, 9.2, 9.3, 9.4],
]),
]
def _biases(self):
return constant_op.constant(
[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
def _div_sharded_biases(self):
return [
constant_op.constant([0.0, 0.1]),
constant_op.constant([0.2, 0.3]),
constant_op.constant([0.4, 0.5]),
constant_op.constant([0.6, 0.7]),
constant_op.constant([0.8, 0.9]),
]
def _mod_sharded_biases(self):
return [
constant_op.constant([0.0, 0.5]),
constant_op.constant([0.1, 0.6]),
constant_op.constant([0.2, 0.7]),
constant_op.constant([0.3, 0.8]),
constant_op.constant([0.4, 0.9]),
]
def _labels(self):
return constant_op.constant(
[[0, 1], [1, 2]],
shape=(self._batch_size, self._num_true),
name='labels',
dtype=dtypes.int64)
def _inputs(self):
return constant_op.constant(
[
[0., 1., 2., 3., 4.],
[10., 11., 12., 13., 14.],
],
shape=(self._batch_size, self._embed_dim),
name='inputs')
def testInvalidNumSampled0(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
r'num_resampled \(3\) must be less than num_sampled \(3\)'):
sampling_ops.rank_sampled_softmax_loss(
weights=self._weights(),
biases=self._biases(),
labels=self._labels(),
inputs=self._inputs(),
num_sampled=3,
num_resampled=3,
num_classes=self._num_classes,
num_true=self._num_true,
sampled_values=None,
resampling_temperature=1.,
remove_accidental_hits=True,
partition_strategy='div')
def testInvalidNumSampled1(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
r'num_resampled \(3\) must be less than num_sampled \(2\)'):
sampling_ops.rank_sampled_softmax_loss(
weights=self._weights(),
biases=self._biases(),
labels=self._labels(),
inputs=self._inputs(),
num_sampled=2,
num_resampled=3,
num_classes=self._num_classes,
num_true=self._num_true,
sampled_values=None,
resampling_temperature=1.,
remove_accidental_hits=True,
partition_strategy='div')
def testMissingPartitionStrategy(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError,
r'unsupported partition_strategy \(None\)'):
sampling_ops.rank_sampled_softmax_loss(
weights=self._weights(),
biases=self._biases(),
labels=self._labels(),
inputs=self._inputs(),
num_sampled=2,
num_resampled=1,
num_classes=self._num_classes,
num_true=self._num_true,
sampled_values=None,
resampling_temperature=1.,
remove_accidental_hits=True,
partition_strategy=None)
def _testCompareWithNN(self, weights, biases, partition_strategy):
with ops.Graph().as_default():
loss = sampling_ops.rank_sampled_softmax_loss(
weights=weights(),
biases=biases(),
labels=self._labels(),
inputs=self._inputs(),
num_sampled=self._num_sampled,
num_resampled=self._num_resampled,
num_classes=self._num_classes,
num_true=self._num_true,
sampled_values=self._sampled_values,
resampling_temperature=1.,
remove_accidental_hits=self._remove_accidental_hits,
partition_strategy=partition_strategy)
loss_nn = nn.sampled_softmax_loss(
weights=weights(),
biases=biases(),
labels=self._labels(),
inputs=self._inputs(),
num_sampled=self._num_resampled,
num_classes=self._num_classes,
num_true=self._num_true,
sampled_values=self._resampled_values,
remove_accidental_hits=self._remove_accidental_hits,
partition_strategy=partition_strategy)
with self.test_session() as sess:
loss_val = sess.run(loss)
loss_nn_val = sess.run(loss_nn)
self.assertAllClose(loss_val, loss_nn_val)
def testCompareWithNNUnsharded(self):
self._testCompareWithNN(self._weights, self._biases, 'div')
def testCompareWithNNShardWeightsDiv(self):
self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')
def testCompareWithNNShardWeightsAndBiasesDiv(self):
self._testCompareWithNN(self._div_sharded_weights, self._div_sharded_biases,
'div')
def testCompareWithNNShardWeightsMod(self):
self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')
def testCompareWithNNShardWeightsAndBiasesMod(self):
self._testCompareWithNN(self._mod_sharded_weights, self._mod_sharded_biases,
'mod')
def _testCompareWithNNTemperature(self, temperature, resampled):
weights = [[1., 2.], [3., 4.]] # two sampled classes
inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
# Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
# Let x0, x1 = inputs
# logits:
# w0.x0 = 1
# w0.x1 = 10
# w1.x0 = 8
# w1.x1 = 9
# Resampling 1 class with temperature = t will pick the larger of:
# exp(1/t) + exp(10/t) ==> w0, for values of t < 2.12
# exp(8/t) + exp(9/t) ==> w1, for values of t > 2.13
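# (Rough check: at t = 2.12 the sums are ~113.4 vs ~113.3, so w0 still wins;
# at t = 2.13 they are ~111.0 vs ~111.2, so w1 wins.)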
num_sampled = 2
num_resampled = 1
num_classes = 2
num_true = 1
sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
resampled_values = [resampled], [[1.], [1.]], [1.]
remove_accidental_hits = False
with ops.Graph().as_default():
weights = constant_op.constant(weights)
biases = constant_op.constant([0., 0.])
labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
inputs = constant_op.constant(inputs)
loss = sampling_ops.rank_sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_resampled=num_resampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
resampling_temperature=constant_op.constant(temperature),
remove_accidental_hits=remove_accidental_hits,
partition_strategy='div')
loss_nn = nn.sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_resampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=resampled_values,
remove_accidental_hits=remove_accidental_hits,
partition_strategy='div')
with self.test_session() as sess:
loss_val = sess.run(loss)
loss_nn_val = sess.run(loss_nn)
self.assertAllClose(loss_val, loss_nn_val)
def testCompareWithNNTemperatureLo1(self):
self._testCompareWithNNTemperature(1., 0)
def testCompareWithNNTemperatureLo2(self):
self._testCompareWithNNTemperature(2.12, 0)
def testCompareWithNNTemperatureHi1(self):
self._testCompareWithNNTemperature(2.13, 1)
def testCompareWithNNTemperatureHi2(self):
self._testCompareWithNNTemperature(3., 1)
if __name__ == '__main__':
test.main()