Swap the input and label arguments in nce_loss

Change: 141244045
This commit is contained in:
Jianwei Xie 2016-12-06 16:31:01 -08:00 committed by TensorFlower Gardener
parent 059ccad4d4
commit 761b12ed82
3 changed files with 27 additions and 23 deletions

View File

@@ -160,8 +160,12 @@ with graph.as_default():
# tf.nce_loss automatically draws a new sample of the negative labels each
# time we evaluate the loss.
loss = tf.reduce_mean(
tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,
num_sampled, vocabulary_size))
tf.nn.nce_loss(weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size))
# Construct the SGD optimizer using a learning rate of 1.0.
optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

View File

@@ -809,8 +809,8 @@ def _sum_rows(x):
def _compute_sampled_logits(weights,
biases,
inputs,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
@@ -834,11 +834,11 @@ def _compute_sampled_logits(weights,
objects whose concatenation along dimension 0 has shape
`[num_classes, dim]`. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
@@ -975,8 +975,8 @@ def _compute_sampled_logits(weights,
def nce_loss(weights,
biases,
inputs,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
@@ -1012,10 +1012,10 @@ def nce_loss(weights,
objects whose concatenation along dimension 0 has shape
[num_classes, dim]. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
@@ -1038,12 +1038,12 @@ def nce_loss(weights,
A `batch_size` 1-D tensor of per-example NCE losses.
"""
logits, labels = _compute_sampled_logits(
weights,
biases,
inputs,
labels,
num_sampled,
num_classes,
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,
@@ -1114,12 +1114,12 @@ def sampled_softmax_loss(weights,
"""
logits, labels = _compute_sampled_logits(
weights,
biases,
inputs,
labels,
num_sampled,
num_classes,
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,

View File

@@ -452,8 +452,8 @@ class ComputeSampledLogitsTest(tf.test.TestCase):
pred_logits_tf, pred_labels_tf = _compute_sampled_logits(
weights_tf,
biases_tf,
hidden_acts_tf,
labels_tf,
hidden_acts_tf,
num_sampled,
num_classes,
num_true,
@@ -672,8 +672,8 @@ class ComputeSampledLogitsTest(tf.test.TestCase):
nce_loss_tf = tf.nn.nce_loss(
weights_tf,
biases_tf,
inputs_tf,
labels_tf,
inputs_tf,
num_sampled=1,
num_classes=self._num_classes,
num_true=1,
@@ -685,8 +685,8 @@ class ComputeSampledLogitsTest(tf.test.TestCase):
nce_loss_tf = tf.nn.nce_loss(
[tf.constant(shard) for shard in sharded_weights],
biases_tf,
inputs_tf,
labels_tf,
inputs_tf,
num_sampled=1,
num_classes=self._num_classes,
num_true=1,