Change arg order for {softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits to be (labels, logits), and force use of named args to avoid accidents.
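To summarize the new calling convention, here is a minimal before/after sketch (the tensors `logits` and `labels` are illustrative placeholders, not taken from the commit):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
labels = tf.constant([[1.0, 0.0, 0.0]])

# Before this change, positional (logits, labels) was accepted:
#   loss = tf.nn.softmax_cross_entropy_with_logits(logits, labels)

# After this change, arguments must be passed by name, labels first:
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

# A positional call now trips the `_sentinel` guard and raises
# ValueError("Only call `softmax_cross_entropy_with_logits` with "
#            "named arguments (labels=..., logits=..., ...)")
```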

Change: 143629623
Martin Wicke 2017-01-04 21:25:34 -08:00 committed by TensorFlower Gardener
parent d9541696b0
commit 333dc32ff7
34 changed files with 165 additions and 105 deletions
tensorflow
contrib
distributions/python/ops
layers/python/layers
learn/python/learn
legacy_seq2seq/python/ops
linear_optimizer/python/ops
losses/python/losses
nn/python/ops
tensor_forest/hybrid/python
examples
g3doc
api_docs/python
how_tos
meta_graph
summaries_and_tensorboard
tutorials/mnist
python
tools/docker/notebooks

View File

@ -147,7 +147,7 @@ class Bernoulli(distribution.Distribution):
distribution_util.same_dynamic_shape(logits, event),
lambda: (logits, event),
lambda: broadcast(logits, event))
return -nn.sigmoid_cross_entropy_with_logits(logits, event)
return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
def _prob(self, event):
return math_ops.exp(self._log_prob(event))

View File

@ -202,7 +202,8 @@ class Categorical(distribution.Distribution):
logits_shape = array_ops.shape(logits)[:-1]
k *= array_ops.ones(logits_shape, dtype=k.dtype)
k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
return -nn_ops.sparse_softmax_cross_entropy_with_logits(logits, k)
return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
logits=logits)
def _prob(self, k):
return math_ops.exp(self._log_prob(k))
@ -214,7 +215,8 @@ class Categorical(distribution.Distribution):
logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
histogram_2d = nn_ops.softmax(logits_2d)
ret = array_ops.reshape(
nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d),
nn_ops.softmax_cross_entropy_with_logits(labels=histogram_2d,
logits=logits_2d),
self.batch_shape())
ret.set_shape(self.get_batch_shape())
return ret

View File

@ -215,7 +215,8 @@ class _OneHotCategorical(distribution.Distribution):
else:
logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
x_2d = array_ops.reshape(x, [-1, self.num_classes])
ret = -nn_ops.softmax_cross_entropy_with_logits(logits_2d, x_2d)
ret = -nn_ops.softmax_cross_entropy_with_logits(labels=x_2d,
logits=logits_2d)
ret = array_ops.reshape(ret, logits_shape)
return ret
@ -229,7 +230,8 @@ class _OneHotCategorical(distribution.Distribution):
logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
histogram_2d = nn_ops.softmax(logits_2d)
ret = array_ops.reshape(
nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d),
nn_ops.softmax_cross_entropy_with_logits(labels=histogram_2d,
logits=logits_2d),
self.batch_shape())
ret.set_shape(self.get_batch_shape())
return ret

View File

@ -491,7 +491,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
columns_to_tensors=columns_to_tensor,
feature_columns=feature_columns,
num_outputs=1)
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
logits=logits)
```
Args:

View File

@ -406,8 +406,8 @@ def _log_loss_with_two_classes(logits, target):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
if len(target.get_shape()) == 1:
target = array_ops.expand_dims(target, dim=[1])
loss_vec = nn.sigmoid_cross_entropy_with_logits(logits,
math_ops.to_float(target))
loss_vec = nn.sigmoid_cross_entropy_with_logits(
labels=math_ops.to_float(target), logits=logits)
return loss_vec
@ -419,7 +419,8 @@ def _softmax_cross_entropy_loss(logits, target):
# sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
if len(target.get_shape()) == 2:
target = array_ops.squeeze(target, squeeze_dims=[1])
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, target)
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(
labels=target, logits=logits)
return loss_vec

View File

@ -466,7 +466,7 @@ def _log_loss_with_two_classes(logits, labels):
if len(labels.get_shape()) == 1:
labels = array_ops.expand_dims(labels, dim=(1,))
return nn.sigmoid_cross_entropy_with_logits(
logits, math_ops.to_float(labels), name=name)
labels=math_ops.to_float(labels), logits=logits, name=name)
def _one_class_to_two_class_logits(logits):
@ -669,7 +669,7 @@ def _softmax_cross_entropy_loss(logits, labels):
if len(labels.get_shape()) == 2:
labels = array_ops.squeeze(labels, squeeze_dims=(1,))
return nn.sparse_softmax_cross_entropy_with_logits(
logits, labels, name=name)
labels=labels, logits=logits, name=name)
class _MultiClassHead(_Head):
@ -1461,7 +1461,7 @@ def _sigmoid_cross_entropy_loss(logits, labels):
(logits, labels)) as name:
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
return nn.sigmoid_cross_entropy_with_logits(
logits, math_ops.to_float(labels), name=name)
labels=math_ops.to_float(labels), logits=logits, name=name)
def _float_weights_or_none(weights):

View File

@ -47,7 +47,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):
predictions, xent_list = [], []
for i, pred in enumerate(decoding):
xent_list.append(nn.softmax_cross_entropy_with_logits(
pred, labels[i],
labels=labels[i], logits=pred,
name="sequence_loss/xent_raw{0}".format(i)))
if sampling_decoding:
predictions.append(nn.softmax(sampling_decoding[i]))

View File

@ -1062,7 +1062,7 @@ def sequence_loss_by_example(logits,
# violates our general scalar strictness policy.
target = array_ops.reshape(target, [-1])
crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
logits=logit, labels=target)
labels=target, logits=logit)
else:
crossent = softmax_loss_function(target, logit)
log_perp_list.append(crossent * weight)

View File

@ -456,10 +456,10 @@ class SdcaModel(object):
dtypes.float64)
if self._options['loss_type'] == 'logistic_loss':
return math_ops.reduce_sum(
math_ops.multiply(
sigmoid_cross_entropy_with_logits(predictions, labels),
weights)) / math_ops.reduce_sum(weights)
return math_ops.reduce_sum(math_ops.multiply(
sigmoid_cross_entropy_with_logits(labels=labels,
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to

View File

@ -340,7 +340,8 @@ def sigmoid_cross_entropy(
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
0.5 * label_smoothing)
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)
@ -387,7 +388,8 @@ def softmax_cross_entropy(
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)
@ -421,7 +423,8 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
weights = array_ops.squeeze(weights)
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)

View File

@ -65,7 +65,7 @@ def deprecated_flipped_softmax_cross_entropy_with_logits(logits,
softmax cross entropy loss.
"""
return nn.softmax_cross_entropy_with_logits(
logits=logits, labels=labels, dim=dim, name=name)
labels=labels, logits=logits, dim=dim, name=name)
# TODO(b/33392402): Formally deprecate this API.
@ -119,7 +119,7 @@ def deprecated_flipped_sparse_softmax_cross_entropy_with_logits(logits,
of the labels is not equal to the rank of the logits minus one.
"""
return nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels, name=name)
labels=labels, logits=logits, name=name)
# TODO(b/33392402): Formally deprecate this API.
@ -174,4 +174,4 @@ def deprecated_flipped_sigmoid_cross_entropy_with_logits(logits,
ValueError: If `logits` and `targets` do not have the same shape.
"""
return nn.sigmoid_cross_entropy_with_logits(
logits=logits, targets=targets, name=name)
labels=targets, logits=logits, name=name)
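For call sites that cannot be updated yet, the `deprecated_flipped_*` wrappers above keep the old positional order; a hedged migration sketch follows (the import path `tensorflow.contrib.nn` is an assumption based on this file's location under contrib/nn):

```python
import tensorflow as tf
# Assumed import path; this file lives under tensorflow/contrib/nn.
from tensorflow.contrib.nn import (
    deprecated_flipped_softmax_cross_entropy_with_logits)

logits = tf.constant([[2.0, 1.0]])
labels = tf.constant([[1.0, 0.0]])

# Legacy code keeps the old positional (logits, labels) order; the wrapper
# forwards to the new keyword-only core op internally.
loss = deprecated_flipped_softmax_cross_entropy_with_logits(logits, labels)
```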

View File

@ -117,8 +117,8 @@ class HybridModel(object):
else:
loss = math_ops.reduce_mean(
nn_ops.sparse_softmax_cross_entropy_with_logits(
self.training_inference_graph(data),
array_ops.squeeze(math_ops.to_int32(labels))),
labels=array_ops.squeeze(math_ops.to_int32(labels)),
logits=self.training_inference_graph(data)),
name="loss")
if self.regularizer:
loss += layers.apply_regularization(self.regularizer,

View File

@ -723,7 +723,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
with tf.name_scope('cross_entropy'):
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
logits, ground_truth_input)
labels=ground_truth_input, logits=logits)
with tf.name_scope('total'):
cross_entropy_mean = tf.reduce_mean(cross_entropy)
tf.summary.scalar('cross_entropy', cross_entropy_mean)

View File

@ -95,9 +95,8 @@ def loss(logits, labels):
"""
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, labels, name='xentropy')
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
return loss
labels=labels, logits=logits, name='xentropy')
return tf.reduce_mean(cross_entropy, name='xentropy_mean')
def training(loss, learning_rate):

View File

@ -54,7 +54,8 @@ def main(_):
#
# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
# outputs of 'y', and then average across the batch.
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()

View File

@ -119,7 +119,7 @@ def train():
# So here we use tf.nn.softmax_cross_entropy_with_logits on the
# raw outputs of the nn_layer above, and then average across
# the batch.
diff = tf.nn.softmax_cross_entropy_with_logits(y, y_)
diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
with tf.name_scope('total'):
cross_entropy = tf.reduce_mean(diff)
tf.summary.scalar('cross_entropy', cross_entropy)

View File

@ -271,7 +271,7 @@
" # cross-entropy across all training examples: that's our loss.\n",
" logits = tf.matmul(tf_train_dataset, weights) + biases\n",
" loss = tf.reduce_mean(\n",
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
" \n",
" # Optimizer.\n",
" # We are going to find the minimum of this loss using gradient descent.\n",
@ -448,7 +448,7 @@
" # Training computation.\n",
" logits = tf.matmul(tf_train_dataset, weights) + biases\n",
" loss = tf.reduce_mean(\n",
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
" \n",
" # Optimizer.\n",
" optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)\n",

View File

@ -286,7 +286,7 @@
" # Training computation.\n",
" logits = model(tf_train_dataset)\n",
" loss = tf.reduce_mean(\n",
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
" \n",
" # Optimizer.\n",
" optimizer = tf.train.GradientDescentOptimizer(0.05).minimize(loss)\n",

View File

@ -576,7 +576,7 @@
" logits = tf.nn.xw_plus_b(tf.concat_v2(outputs, 0), w, b)\n",
" loss = tf.reduce_mean(\n",
" tf.nn.softmax_cross_entropy_with_logits(\n",
" logits, tf.concat_v2(train_labels, 0)))\n",
" labels=tf.concat_v2(train_labels, 0), logits=logits))\n",
"\n",
" # Optimizer.\n",
" global_step = tf.Variable(0)\n",

View File

@ -1768,7 +1768,8 @@ Example:
columns_to_tensors=columns_to_tensor,
feature_columns=feature_columns,
num_outputs=1)
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
logits=logits)
```
##### Args:

View File

@ -20,7 +20,8 @@ Example:
columns_to_tensors=columns_to_tensor,
feature_columns=feature_columns,
num_outputs=1)
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
logits=logits)
```
##### Args:

View File

@ -227,9 +227,8 @@ Here are some of the typical usage models:
onehot_labels = tf.sparse_to_dense(
concated, tf.pack([batch_size, 10]), 1.0, 0.0)
logits = tf.get_collection("logits")[0]
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
onehot_labels,
name="xentropy")
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
labels=onehot_labels, logits=logits, name="xentropy")
loss = tf.reduce_mean(cross_entropy, name="xentropy_mean")
tf.summary.scalar('loss', loss)

View File

@ -135,7 +135,7 @@ with tf.name_scope('cross_entropy'):
# So here we use tf.nn.softmax_cross_entropy_with_logits on the
# raw outputs of the nn_layer above, and then average across
# the batch.
diff = tf.nn.softmax_cross_entropy_with_logits(y, y_)
diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
with tf.name_scope('total'):
cross_entropy = tf.reduce_mean(diff)
tf.summary.scalar('cross_entropy', cross_entropy)

View File

@ -173,7 +173,8 @@ between the target and the softmax activation function applied to the model's
prediction. As in the beginners tutorial, we use the stable formulation:
```python
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
```
Note that `tf.nn.softmax_cross_entropy_with_logits` internally applies the
@ -394,7 +395,8 @@ Feel free to go ahead and run this code, but it does 20,000 training iterations
and may take a while (possibly up to half an hour), depending on your processor.
```python
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_conv, y_))
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

View File

@ -171,7 +171,7 @@ First, the values from the `labels_placeholder` are converted to 64-bit integers
```python
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, labels, name='xentropy')
labels=labels, logits=logits, name='xentropy')
```
It then uses [`tf.reduce_mean`](../../../api_docs/python/math_ops.md#reduce_mean)

View File

@ -141,25 +141,26 @@ class SparseXentTest(test.TestCase):
with self.test_session(use_gpu=True):
with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
[[0., 1.], [2., 3.], [2., 3.]], [[0, 2]])
labels=[[0, 2]], logits=[[0., 1.], [2., 3.], [2., 3.]])
def testScalar(self):
with self.test_session(use_gpu=True):
with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
constant_op.constant(1.0), constant_op.constant(0))
labels=constant_op.constant(0), logits=constant_op.constant(1.0))
def testLabelsPlaceholderScalar(self):
with self.test_session(use_gpu=True):
labels = array_ops.placeholder(np.int32)
y = nn_ops.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
y = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=[[7.]])
with self.assertRaisesOpError("labels must be 1-D"):
y.eval(feed_dict={labels: 0})
def testVector(self):
with self.test_session(use_gpu=True):
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
constant_op.constant([1.0]), constant_op.constant(0))
labels=constant_op.constant(0), logits=constant_op.constant([1.0]))
self.assertAllClose(0.0, loss.eval())
def testFloat(self):
@ -191,7 +192,8 @@ class SparseXentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(f, l, name="xent")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
@ -201,7 +203,8 @@ class SparseXentTest(test.TestCase):
# manually reshape loss
np_loss = np.reshape(np_loss, np.array(labels).shape)
with self.test_session(use_gpu=True) as sess:
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=features)
backprop = loss.op.inputs[0].op.outputs[1]
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
@ -225,7 +228,7 @@ class SparseXentTest(test.TestCase):
labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])
logits = array_ops.placeholder(dtypes.float32, shape=[None, 3])
ce = nn_ops.sparse_softmax_cross_entropy_with_logits(
logits, array_ops.squeeze(labels))
labels=array_ops.squeeze(labels), logits=logits)
labels_v2 = np.zeros((1, 1), dtype=np.int32)
logits_v2 = np.random.randn(1, 3)
sess.run([ce], feed_dict={labels: labels_v2, logits: logits_v2})
@ -243,7 +246,7 @@ def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
array_ops.stack([length]), 1.0, 0.0)
target = array_ops.reshape(target, array_ops.stack([-1, num_entries]))
crossent = nn_ops.softmax_cross_entropy_with_logits(
logits, target, name="SequenceLoss/CrossEntropy")
labels=target, logits=logits, name="SequenceLoss/CrossEntropy")
crossent_sum = math_ops.reduce_sum(crossent)
grads = gradients_impl.gradients([crossent_sum], [logits])[0]

View File

@ -57,7 +57,7 @@ class XentTest(test.TestCase):
np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
with self.test_session(use_gpu=use_gpu) as sess:
loss = nn_ops.softmax_cross_entropy_with_logits(
np_features, np_labels, dim=dim)
labels=np_labels, logits=np_features, dim=dim)
tf_loss = sess.run(loss)
print("np_loss:", np_loss)
print("tf_loss:", tf_loss)
@ -166,7 +166,8 @@ class XentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.softmax_cross_entropy_with_logits(f, l, name="xent")
x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)

View File

@ -267,7 +267,7 @@ class MiniMNISTTest(test.TestCase):
dtype=dtypes.float64,
name="labels")
cost = nn_ops.softmax_cross_entropy_with_logits(
logits, labels, name="cost")
labels=labels, logits=logits, name="cost")
# Test the gradients.
err = gradient_checker.compute_gradient_error(

View File

@ -559,7 +559,8 @@ def sigmoid_cross_entropy(
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
0.5 * label_smoothing)
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope, loss_collection)
@ -613,7 +614,8 @@ def softmax_cross_entropy(
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope, loss_collection)
@ -653,7 +655,8 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
[logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits,
name="xentropy")
# Reshape losses to [batch_size, 1] to be consistent with weights.
losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])

View File

@ -96,7 +96,9 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
return result
def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
def sigmoid_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
name=None):
"""Computes sigmoid cross entropy given `logits`.
Measures the probability error in discrete classification tasks in which each
@ -104,7 +106,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
perform multilabel classification where a picture can contain both an elephant
and a dog at the same time.
For brevity, let `x = logits`, `z = targets`. The logistic loss is
For brevity, let `x = logits`, `z = labels`. The logistic loss is
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
@ -124,11 +126,12 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
max(x, 0) - x * z + log(1 + exp(-abs(x)))
`logits` and `targets` must have the same type and shape.
`logits` and `labels` must have the same type and shape.
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: A `Tensor` of the same type and shape as `logits`.
logits: A `Tensor` of type `float32` or `float64`.
targets: A `Tensor` of the same type and shape as `logits`.
name: A name for the operation (optional).
Returns:
@ -136,16 +139,21 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
logistic losses.
Raises:
ValueError: If `logits` and `targets` do not have the same shape.
ValueError: If `logits` and `labels` do not have the same shape.
"""
with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
# pylint: disable=protected-access
nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits",
_sentinel, labels, logits)
# pylint: enable=protected-access
with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
logits = ops.convert_to_tensor(logits, name="logits")
targets = ops.convert_to_tensor(targets, name="targets")
labels = ops.convert_to_tensor(labels, name="labels")
try:
targets.get_shape().merge_with(logits.get_shape())
labels.get_shape().merge_with(logits.get_shape())
except ValueError:
raise ValueError("logits and targets must have the same shape (%s vs %s)"
% (logits.get_shape(), targets.get_shape()))
raise ValueError("logits and labels must have the same shape (%s vs %s)"
% (logits.get_shape(), labels.get_shape()))
# The logistic loss formula from above is
# x - x * z + log(1 + exp(-x))
@ -159,7 +167,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
cond = (logits >= zeros)
relu_logits = array_ops.where(cond, logits, zeros)
neg_abs_logits = array_ops.where(cond, -logits, logits)
return math_ops.add(relu_logits - logits * targets,
return math_ops.add(relu_logits - logits * labels,
math_ops.log1p(math_ops.exp(neg_abs_logits)),
name=name)
@ -1095,7 +1103,7 @@ def nce_loss(weights,
partition_strategy=partition_strategy,
name=name)
sampled_losses = sigmoid_cross_entropy_with_logits(
logits, labels, name="sampled_losses")
labels=labels, logits=logits, name="sampled_losses")
# sampled_losses is batch_size x {true_loss, sampled_losses...}
# We sum out true and sampled losses.
return _sum_rows(sampled_losses)
@ -1170,6 +1178,7 @@ def sampled_softmax_loss(weights,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(labels=labels,
logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
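The docstring above derives the stable form `max(x, 0) - x * z + log(1 + exp(-abs(x)))`; a small NumPy sketch (illustrative, not part of the commit) confirms it matches the naive logistic loss on overflow-safe inputs:

```python
import numpy as np

def naive(x, z):
    # Direct logistic loss from the docstring; overflows for large |x|.
    return z * -np.log(1 / (1 + np.exp(-x))) + (1 - z) * -np.log(
        np.exp(-x) / (1 + np.exp(-x)))

def stable(x, z):
    # Equivalent reformulation used by sigmoid_cross_entropy_with_logits.
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

x = np.array([-3.0, -0.5, 0.0, 2.0])
z = np.array([0.0, 1.0, 1.0, 0.0])
assert np.allclose(naive(x, z), stable(x, z))
```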

View File

@ -14,7 +14,6 @@
# ==============================================================================
"""Wrappers for primitive Neural Net (NN) Operations."""
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -1047,7 +1046,7 @@ def conv2d_transpose(value,
raise ValueError("data_format has to be either NCHW or NHWC.")
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
axis = 3 if data_format=="NHWC" else 1
axis = 3 if data_format == "NHWC" else 1
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[3]):
raise ValueError("input channels does not match filter's input channels, "
"{} != {}".format(value.get_shape()[3], filter.get_shape(
@ -1528,7 +1527,18 @@ def log_softmax(logits, dim=-1, name=None):
return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
def _ensure_xent_args(name, sentinel, labels, logits):
# Make sure that all arguments were passed as named arguments.
if sentinel is not None:
raise ValueError("Only call `%s` with "
"named arguments (labels=..., logits=..., ...)" % name)
if labels is None or logits is None:
raise ValueError("Both labels and logits must be provided.")
def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
dim=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@ -1551,9 +1561,13 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
and the same dtype (either `float16`, `float32`, or `float64`).
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
Args:
logits: Unscaled log probabilities.
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: Each row `labels[i]` must be a valid probability distribution.
logits: Unscaled log probabilities.
dim: The class dimension. Defaulted to -1 which is the last dimension.
name: A name for the operation (optional).
@ -1561,6 +1575,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
softmax cross entropy loss.
"""
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
labels, logits)
# TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
# could break users who call this with bad labels, but disregard the bad
# results.
@ -1569,7 +1586,7 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
labels = ops.convert_to_tensor(labels)
precise_logits = math_ops.cast(logits, dtypes.float32) if (
logits.dtype == dtypes.float16) else logits
# Labels and logits must be of the same type
# labels and logits must be of the same type
labels = math_ops.cast(labels, precise_logits.dtype)
input_rank = array_ops.rank(precise_logits)
# For shape inference.
@ -1618,7 +1635,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
return cost
def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
name=None):
"""Computes sparse softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@ -1640,14 +1659,18 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
A common use case is to have logits of shape `[batch_size, num_classes]` and
labels of shape `[batch_size]`. But higher dimensions are supported.
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
Args:
logits: Unscaled log probabilities of rank `r` and shape
`[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
`int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
Other values will raise an exception when this op is run on CPU, and
return `NaN` for the corresponding loss and gradient rows
on GPU.
logits: Unscaled log probabilities of rank `r` and shape
`[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
name: A name for the operation (optional).
Returns:
@ -1658,6 +1681,9 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
ValueError: If logits are scalars (need to have rank >= 1) or if the rank
of the labels is not equal to the rank of the logits minus one.
"""
_ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
labels, logits)
# TODO(pcmurray) Raise an error when the label is not an index in
# [0, num_classes). Note: This could break users who call this with bad
# labels, but disregard the bad results.
@ -1679,8 +1705,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
if logits.get_shape().ndims is not None and (
labels_static_shape.ndims is not None and
labels_static_shape.ndims != logits.get_shape().ndims - 1):
raise ValueError("Rank mismatch: Rank of labels (received %s) should equal "
"rank of logits minus 1 (received %s)." %
raise ValueError("Rank mismatch: Rank of labels (received %s) should "
"equal rank of logits minus 1 (received %s)." %
(labels_static_shape.ndims, logits.get_shape().ndims))
# Check if no reshapes are required.
if logits.get_shape().ndims == 2:
@ -1857,8 +1883,7 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
return bias_add_v1(mm, biases, name=name)
# pylint: disable=invalid-name
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
With probability `keep_prob`, outputs the input element scaled up by
@ -2082,5 +2107,3 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
rates=rates,
padding=padding,
name=name))
# pylint: enable=invalid-name
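As an aside, the `_sentinel` guard added above generalizes to any keyword-only API; a toy sketch (the function `toy_loss` is hypothetical, and the guard mirrors `_ensure_xent_args` from this commit) shows how positional calls are rejected:

```python
def _ensure_named_args(name, sentinel, labels, logits):
    # Any positional argument lands in the sentinel slot and is rejected.
    if sentinel is not None:
        raise ValueError("Only call `%s` with "
                         "named arguments (labels=..., logits=..., ...)" % name)
    if labels is None or logits is None:
        raise ValueError("Both labels and logits must be provided.")

def toy_loss(_sentinel=None, labels=None, logits=None):
    _ensure_named_args("toy_loss", _sentinel, labels, logits)
    return [l - y for l, y in zip(logits, labels)]

toy_loss(labels=[1.0], logits=[0.5])   # OK
# toy_loss([1.0], [0.5])               # raises ValueError: positional args
```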

View File

@ -57,7 +57,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
with self.test_session():
logits, targets, _ = self._Inputs()
loss = nn_impl.sigmoid_cross_entropy_with_logits(
logits, targets, name="mylogistic")
labels=targets, logits=logits, name="mylogistic")
self.assertEqual("mylogistic", loss.op.name)
def testLogisticOutput(self):
@ -65,7 +65,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float16]:
with self.test_session(use_gpu=use_gpu):
logits, targets, losses = self._Inputs(dtype=dtype)
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@ -75,7 +76,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float16]:
with self.test_session(use_gpu=use_gpu):
logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@ -84,7 +86,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
sizes = [4, 2]
with self.test_session():
logits, targets, _ = self._Inputs(sizes=sizes)
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits)
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
@ -93,13 +96,15 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
with self.test_session():
logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits)
grads = gradients_impl.gradients(loss, logits)[0].eval()
self.assertAllClose(grads, [0.5, -0.5])
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
nn_impl.sigmoid_cross_entropy_with_logits([[2, 1]], [1, 2, 3])
nn_impl.sigmoid_cross_entropy_with_logits(labels=[1, 2, 3],
logits=[[2, 1]])
class WeightedCrossEntropyTest(test.TestCase):
@ -128,15 +133,15 @@ class WeightedCrossEntropyTest(test.TestCase):
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs()
loss = nn_impl.weighted_cross_entropy_with_logits(
targets, logits, pos_weight, name="mybce")
targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
self.assertEqual("mybce", loss.op.name)
def testOutput(self):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(dtype=dtypes.float32)
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
loss = nn_impl.weighted_cross_entropy_with_logits(
targets=targets, logits=logits, pos_weight=pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@ -146,8 +151,8 @@ class WeightedCrossEntropyTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(
dtype=dtypes.float32, sizes=[2, 2, 2])
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
loss = nn_impl.weighted_cross_entropy_with_logits(
targets=targets, logits=logits, pos_weight=pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@ -156,15 +161,16 @@ class WeightedCrossEntropyTest(test.TestCase):
sizes = [4, 2]
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
loss = nn_impl.weighted_cross_entropy_with_logits(
targets=targets, logits=logits, pos_weight=pos_weight)
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
nn_impl.weighted_cross_entropy_with_logits([1, 2, 3], [[2, 1]], 2.0)
nn_impl.weighted_cross_entropy_with_logits(
targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
if __name__ == "__main__":

View File

@ -1614,7 +1614,7 @@ class MetaGraphTest(test.TestCase):
concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
logits = ops_lib.get_collection("logits")[0]
cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
logits, onehot_labels, name="xentropy")
labels=onehot_labels, logits=logits, name="xentropy")
loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")
summary.scalar("loss", loss)
@ -1698,7 +1698,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label, name="cost")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
logits=logit, name="cost")
adam.AdamOptimizer().minimize(cost, name="optimize")
saver = saver_module.Saver()
sess.run(variables.global_variables_initializer())
@ -1726,7 +1727,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph()
@ -1758,7 +1760,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
graph_io.write_graph(meta_graph_def, "/tmp", "meta_graph.pbtxt")

View File

@ -1206,7 +1206,7 @@
"# Training computation: logits + cross-entropy loss.\n",
"logits = model(train_data_node, True)\n",
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
" logits, train_labels_node))\n",
" labels=train_labels_node, logits=logits))\n",
"\n",
"# L2 regularization for the fully connected parameters.\n",
"regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +\n",