Add tf.contrib.nn.rank_sampled_softmax_loss, a variant of tf.nn.sampled_softmax_loss that has been shown to improve rank loss. Paper: https://arxiv.org/abs/1707.03073
PiperOrigin-RevId: 161702455
commit c9d03a568a (parent 9aa0dcbf28)
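As a quick orientation before the diff, here is a minimal call sketch of the new op (shapes and argument values below are illustrative, not taken from this change):

```python
import tensorflow as tf

# Hypothetical sizes: 1000 classes, 64-dim embeddings, batch of 32.
loss = tf.contrib.nn.rank_sampled_softmax_loss(
    weights=weights,            # [1000, 64] class embeddings
    biases=biases,              # [1000] class biases
    labels=labels,              # [32, 1] int64 target classes
    inputs=inputs,              # [32, 64] forward activations
    num_sampled=64,
    num_resampled=16,           # must be less than num_sampled
    num_classes=1000,
    num_true=1,
    sampled_values=None,        # defaults to learned_unigram_candidate_sampler
    resampling_temperature=1.0,
    remove_accidental_hits=True,
    partition_strategy="div")
```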

tensorflow/contrib/nn/BUILD
@@ -7,6 +7,8 @@ exports_files(["LICENSE"])

 package(default_visibility = ["//visibility:public"])

+load("//tensorflow:tensorflow.bzl", "py_test")
+
 py_library(
     name = "nn_py",
     srcs = [
@@ -14,15 +16,34 @@ py_library(
         "python/__init__.py",
         "python/ops/__init__.py",
         "python/ops/cross_entropy.py",
+        "python/ops/sampling_ops.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:nn",
         "//tensorflow/python:util",
     ],
 )

+py_test(
+    name = "sampling_ops_test",
+    size = "small",
+    srcs = ["python/ops/sampling_ops_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":nn_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:nn",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(

tensorflow/contrib/nn/python/ops/sampling_ops.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops related to candidate sampling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """A helper function for rank_sampled_softmax_loss.

  This computes, for each i in `sampled_values`,

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i, b_i are the weight and bias of the i-th class, respectively,
  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
  computation to

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature.

  This translates to the following batched computation using tensorflow ops:

      reduce_logsumexp(matmul(embeddings,
                              transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The computation of the first term is colocated with the embeddings using
  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
  term, not the bottleneck, is computed at the worker.

  Args:
    weights: From `rank_sampled_softmax_loss`.
    biases: From `rank_sampled_softmax_loss`.
    inputs: From `rank_sampled_softmax_loss`.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
    num_resampled: An `int`. This many values are selected from
        `sampled_values` using the adaptive resampling algorithm. The caller
        must ensure that `num_resampled` is less than the size of
        `sampled_values`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    partition_strategy: From `rank_sampled_softmax_loss`.

  Returns:
    A tuple of (`resampled_candidates`, `true_expected_count`,
        `resampled_expected_count`), similar to `sampled_values` but sampled
        down to `num_resampled` values.
  """
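  # Note on the rearrangement in the docstring: b_i / resampling_temperature
  # does not depend on j, so it factors out of the logsumexp:
  #   log(sum_j exp((w_i * x_j + b_i) / T))
  #       = log(exp(b_i / T) * sum_j exp(w_i * (x_j / T)))
  #       = log(sum_j exp(w_i * (x_j / T))) + b_i / T.
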
  # This code supports passing a Tensor for num_resampled, but since it is only
  # called with an int, that's what we specify in the arg list. If this
  # function is ever externalized, we should change the doc to support Tensor.

  sampled, true_expected_count, sampled_expected_count = sampled_values

  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
  true_expected_count = array_ops.stop_gradient(true_expected_count)
  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)

  reweighted_inputs = inputs / resampling_temperature

  def logsumexp_logit(embeddings):
    return math_ops.reduce_logsumexp(
        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
        axis=1,
        keep_dims=False)

  # Calling this protected form of embedding_lookup allows co-locating
  # the logsumexp computation with the partitioned weights, which yields
  # a large speedup in practice.
  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
  sampled_b = array_ops.reshape(
      embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
  sampled_logits += sampled_b / resampling_temperature

  _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
  resampled = array_ops.gather(sampled, indices=resampled_indices)
  resampled_expected_count = array_ops.gather(
      sampled_expected_count, indices=resampled_indices)

  return resampled, true_expected_count, resampled_expected_count


# TODO(ccolby): Before checkin, add reference to TAPAS paper when in arxiv.org.
def rank_sampled_softmax_loss(weights,
                              biases,
                              labels,
                              inputs,
                              num_sampled,
                              num_resampled,
                              num_classes,
                              num_true,
                              sampled_values,
                              resampling_temperature,
                              remove_accidental_hits,
                              partition_strategy,
                              name=None):
  """Computes softmax loss using rank-based adaptive resampling.

  This has been shown to improve rank loss after training compared to
  @{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some
  experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
  Sampling for Softmax](https://arxiv.org/abs/1707.03073).

  Sampling follows two phases:
  * In the first phase, `num_sampled` classes are selected using
    @{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`.
    The logits are calculated on those sampled classes. This phase is
    similar to @{tf.nn.sampled_softmax_loss}.
  * In the second phase, the `num_resampled` classes with highest predicted
    probability are kept. Probabilities are
    `LogSumExp(logits / resampling_temperature)`, where the sum is over
    `inputs`.

  The `resampling_temperature` parameter controls the "adaptiveness" of the
  resampling. At lower temperatures, resampling is more adaptive because it
  picks more candidates close to the predicted classes. A common strategy is
  to decrease the temperature as training proceeds.

  See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and
  for typical default values for some of the parameters.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = rank_sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    labels_one_hot = tf.one_hot(labels, n_classes)
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels_one_hot,
        logits=logits)
  ```

  Args:
    weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
        or a list of `Tensor` objects whose concatenation along dimension 0
        has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
        The (possibly-sharded) class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
        num_true]`. The target classes. Note that this format differs from
        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
        activations of the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
    num_resampled: An `int`. The number of classes to select from the
        `num_sampled` classes using the adaptive resampling algorithm. Must be
        less than `num_sampled`.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`. The number of target classes per training example.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None, defaults to `nn.learned_unigram_candidate_sampler`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        See @{tf.nn.embedding_lookup} for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  Raises:
    ValueError: If `num_sampled <= num_resampled`, if
        `num_sampled > num_classes`, or if `partition_strategy` is not
        `"div"` or `"mod"`.
  """
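  # For intuition on phase 2 (example numbers, not from the implementation):
  # with two sampled classes whose logits over a batch of two inputs are
  # [1., 10.] and [8., 9.], and resampling_temperature=1., the phase-2 scores
  # are logsumexp([1., 10.]) ~= 10.0 and logsumexp([8., 9.]) ~= 9.3, so with
  # num_resampled=1 only the first class would be kept.
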
  if num_sampled > num_classes:
    raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
                     format(num_sampled, num_classes))
  if num_sampled <= num_resampled:
    raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
                     format(num_resampled, num_sampled))
  if partition_strategy not in ("div", "mod"):
    raise ValueError(
        "unsupported partition_strategy ({})".format(partition_strategy))
  with ops.name_scope(name, "rank_sampled_softmax_loss", [
      weights, biases, labels, inputs, sampled_values, resampling_temperature
  ]) as name:
    if not sampled_values:
      sampled_values = nn.learned_unigram_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # From sampled_values, select the top num_resampled values using the
    # adaptive rank resampling strategy.
    resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
                                      num_resampled, resampling_temperature,
                                      partition_strategy)
    return nn.sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        num_sampled=num_resampled,
        num_classes=num_classes,
        num_true=num_true,
        sampled_values=resampled_values,
        remove_accidental_hits=remove_accidental_hits,
        partition_strategy=partition_strategy,
        name=name)

tensorflow/contrib/nn/python/ops/sampling_ops_test.py (new file, 322 lines)
@@ -0,0 +1,322 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sampling_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.nn.python.ops import sampling_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class RankSampledSoftmaxLossTest(test.TestCase):

  def setUp(self):
    self._sampled = [3, 4, 5, 6, 7]
    self._num_sampled = len(self._sampled)
    # Because values of all matrices increase with indices, logits increase
    # with class id. So, for the above sampled classes, adaptive sampling
    # will select these resampled classes.
    self._resampled = [5, 6, 7]
    self._num_resampled = len(self._resampled)
    self._num_classes = 10
    self._num_true = 2
    self._sampled_values = (self._sampled, [[0.5], [0.5]],
                            [0.5, 0.5, 0.5, 0.5, 0.5])
    self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
    self._remove_accidental_hits = False
    self._embed_dim = 5
    self._batch_size = 2

  def _weights(self):
    return constant_op.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4],
        [1.0, 1.1, 1.2, 1.3, 1.4],
        [2.0, 2.1, 2.2, 2.3, 2.4],
        [3.0, 3.1, 3.2, 3.3, 3.4],
        [4.0, 4.1, 4.2, 4.3, 4.4],
        [5.0, 5.1, 5.2, 5.3, 5.4],
        [6.0, 6.1, 6.2, 6.3, 6.4],
        [7.0, 7.1, 7.2, 7.3, 7.4],
        [8.0, 8.1, 8.2, 8.3, 8.4],
        [9.0, 9.1, 9.2, 9.3, 9.4],
    ])

  def _div_sharded_weights(self):
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [1.0, 1.1, 1.2, 1.3, 1.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [3.0, 3.1, 3.2, 3.3, 3.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [6.0, 6.1, 6.2, 6.3, 6.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [8.0, 8.1, 8.2, 8.3, 8.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _mod_sharded_weights(self):
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [1.0, 1.1, 1.2, 1.3, 1.4],
            [6.0, 6.1, 6.2, 6.3, 6.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [3.0, 3.1, 3.2, 3.3, 3.4],
            [8.0, 8.1, 8.2, 8.3, 8.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _biases(self):
    return constant_op.constant(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

  def _div_sharded_biases(self):
    return [
        constant_op.constant([0.0, 0.1]),
        constant_op.constant([0.2, 0.3]),
        constant_op.constant([0.4, 0.5]),
        constant_op.constant([0.6, 0.7]),
        constant_op.constant([0.8, 0.9]),
    ]

  def _mod_sharded_biases(self):
    return [
        constant_op.constant([0.0, 0.5]),
        constant_op.constant([0.1, 0.6]),
        constant_op.constant([0.2, 0.7]),
        constant_op.constant([0.3, 0.8]),
        constant_op.constant([0.4, 0.9]),
    ]

  def _labels(self):
    return constant_op.constant(
        [[0, 1], [1, 2]],
        shape=(self._batch_size, self._num_true),
        name='labels',
        dtype=dtypes.int64)

  def _inputs(self):
    return constant_op.constant(
        [
            [0., 1., 2., 3., 4.],
            [10., 11., 12., 13., 14.],
        ],
        shape=(self._batch_size, self._embed_dim),
        name='inputs')

  def testInvalidNumSampled0(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError,
          r'num_resampled \(3\) must be less than num_sampled \(3\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=3,
            num_resampled=3,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy='div')

  def testInvalidNumSampled1(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError,
          r'num_resampled \(3\) must be less than num_sampled \(2\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=2,
            num_resampled=3,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy='div')

  def testMissingPartitionStrategy(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError,
                                   r'unsupported partition_strategy \(None\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=2,
            num_resampled=1,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy=None)

  def _testCompareWithNN(self, weights, biases, partition_strategy):
    with ops.Graph().as_default():
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_sampled,
          num_resampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._sampled_values,
          resampling_temperature=1.,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      loss_nn = nn.sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._resampled_values,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      with self.test_session() as sess:
        loss_val = sess.run(loss)
        loss_nn_val = sess.run(loss_nn)

    self.assertAllClose(loss_val, loss_nn_val)

  def testCompareWithNNUnsharded(self):
    self._testCompareWithNN(self._weights, self._biases, 'div')

  def testCompareWithNNShardWeightsDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')

  def testCompareWithNNShardWeightsAndBiasesDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._div_sharded_biases,
                            'div')

  def testCompareWithNNShardWeightsMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')

  def testCompareWithNNShardWeightsAndBiasesMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._mod_sharded_biases,
                            'mod')

  def _testCompareWithNNTemperature(self, temperature, resampled):
    weights = [[1., 2.], [3., 4.]]  # two sampled classes
    inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
    # Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
    # Let x0, x1 = inputs
    # logits:
    #   w0.x0 = 1
    #   w0.x1 = 10
    #   w1.x0 = 8
    #   w1.x1 = 9
    # Resampling 1 class with temperature = t will pick the larger of:
    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
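    # (A quick numeric check of the crossover, with rounded values: at
    # t = 2.125, exp(1/t) + exp(10/t) ~= 1.6 + 110.6 ~= 112.2 and
    # exp(8/t) + exp(9/t) ~= 43.2 + 69.1 ~= 112.2, so the two scores tie
    # between the two temperature settings exercised below.)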
    num_sampled = 2
    num_resampled = 1
    num_classes = 2
    num_true = 1
    sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
    resampled_values = [resampled], [[1.], [1.]], [1.]
    remove_accidental_hits = False
    with ops.Graph().as_default():
      weights = constant_op.constant(weights)
      biases = constant_op.constant([0., 0.])
      labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
      inputs = constant_op.constant(inputs)
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_sampled,
          num_resampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=sampled_values,
          resampling_temperature=constant_op.constant(temperature),
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      loss_nn = nn.sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=resampled_values,
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      with self.test_session() as sess:
        loss_val = sess.run(loss)
        loss_nn_val = sess.run(loss_nn)

    self.assertAllClose(loss_val, loss_nn_val)

  def testCompareWithNNTemperatureLo1(self):
    self._testCompareWithNNTemperature(1., 0)

  def testCompareWithNNTemperatureLo2(self):
    self._testCompareWithNNTemperature(2.12, 0)

  def testCompareWithNNTemperatureHi1(self):
    self._testCompareWithNNTemperature(2.13, 1)

  def testCompareWithNNTemperatureHi2(self):
    self._testCompareWithNNTemperature(3., 1)


if __name__ == '__main__':
  test.main()