Change weight broadcasting for losses.sparse_softmax_cross_entropy to match sparse_softmax_cross_entropy_with_logits. This means allowing inputs of rank greater than 2.

Also, fix some pydoc.
Change: 144898465
A. Unique TensorFlower 2017-01-18 16:33:20 -08:00 committed by TensorFlower Gardener
parent d8fd68242e
commit 91a6f2f4d6
7 changed files with 361 additions and 97 deletions
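To make the intent concrete, here is a minimal sketch of the kind of call this change is meant to allow: rank-3 logits with matching rank-2 labels and element-wise weights. It assumes the TF 1.x-era tf.losses graph API; the shapes, values, and variable names are illustrative and not taken from the commit.

import tensorflow as tf

# [batch, time, num_classes] logits with [batch, time] integer labels,
# e.g. a sequence-tagging setup; previously only rank-2 logits were accepted.
logits = tf.constant([[[2.0, 0.5, 0.1], [0.1, 3.0, 0.2], [0.3, 0.2, 2.5]],
                      [[1.5, 0.1, 0.4], [0.2, 0.3, 2.0], [2.2, 0.1, 0.1]]])
labels = tf.constant([[0, 1, 2], [0, 2, 0]], dtype=tf.int64)

# Element-wise weights with the same shape as labels (e.g. a padding mask).
weights = tf.constant([[1.0, 1.0, 0.0],
                       [1.0, 1.0, 1.0]])

loss = tf.losses.sparse_softmax_cross_entropy(
    labels=labels, logits=logits, weights=weights)

with tf.Session() as sess:
  print(sess.run(loss))  # scalar weighted-mean cross-entropy

Before this change, the documented contract restricted labels to [batch_size] or [batch_size, 1] and weights to a scalar or [batch_size, 1].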

@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
@@ -215,5 +216,239 @@ class ConfusionMatrixTest(test.TestCase):
self.assertEqual(tf_cm.dtype, np.int64)
class RemoveSqueezableDimensionsTest(test.TestCase):
def testBothScalarShape(self):
label_values = 1.0
prediction_values = 0.0
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values))
labels_placeholder = array_ops.placeholder(dtype=dtypes.float32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.float32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
with self.test_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testSameShape(self):
label_values = np.ones(shape=(2, 3, 1))
prediction_values = np.zeros_like(label_values)
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
with self.test_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testSameShapeExpectedRankDiff0(self):
label_values = np.ones(shape=(2, 3, 1))
prediction_values = np.zeros_like(label_values)
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values, expected_rank_diff=0))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder, expected_rank_diff=0))
with self.test_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testSqueezableLabels(self):
label_values = np.ones(shape=(2, 3, 1))
prediction_values = np.zeros(shape=(2, 3))
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
with self.test_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testSqueezableLabelsExpectedRankDiffPlus1(self):
label_values = np.ones(shape=(2, 3, 1))
prediction_values = np.zeros(shape=(2, 3, 5))
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values, expected_rank_diff=1))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder, expected_rank_diff=1))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
with self.test_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testSqueezablePredictions(self):
label_values = np.ones(shape=(2, 3))
prediction_values = np.zeros(shape=(2, 3, 1))
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
with self.test_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
expected_prediction_values,
dynamic_predictions.eval(feed_dict=feed_dict))
def testSqueezablePredictionsExpectedRankDiffMinus1(self):
label_values = np.ones(shape=(2, 3, 5))
prediction_values = np.zeros(shape=(2, 3, 1))
static_labels, static_predictions = (
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values, expected_rank_diff=-1))
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder, expected_rank_diff=-1))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
with self.test_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
self.assertAllEqual(
expected_prediction_values,
dynamic_predictions.eval(feed_dict=feed_dict))
def testUnsqueezableLabels(self):
label_values = np.ones(shape=(2, 3, 2))
prediction_values = np.zeros(shape=(2, 3))
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values)
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
with self.test_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Tried to explicitly squeeze dimension 2"):
dynamic_labels.eval(feed_dict=feed_dict)
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
def testUnsqueezablePredictions(self):
label_values = np.ones(shape=(2, 3))
prediction_values = np.zeros(shape=(2, 3, 2))
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values)
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
with self.test_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
}
self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict))
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Tried to explicitly squeeze dimension 2"):
dynamic_predictions.eval(feed_dict=feed_dict)
if __name__ == "__main__":
test.main()

@@ -303,7 +303,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
weights = 2.3
with self.test_session():
loss = losses.sparse_softmax_cross_entropy(
labels, logits, constant_op.constant(weights, shape=(1, 1)))
labels, logits, constant_op.constant((weights,)))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
def testNonZeroLossWithPlaceholderForWeights(self):
@@ -432,7 +432,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[0, 1], [2, 3]])
weights = constant_op.constant(1.2)
with self.assertRaises(errors_impl.InvalidArgumentError):
with self.assertRaisesRegexp(ValueError, 'dimension'):
losses.sparse_softmax_cross_entropy(
labels, logits, weights=weights).eval()
@@ -1241,6 +1241,9 @@ class ComputeWeightedLossTest(test.TestCase):
def testInvalid1x2Weight(self):
self._test_invalid_weights((17.0, 3.0,),)
def testInvalidScalar1DWeight(self):
self._test_invalid_weights((17.0,),)
def _test_valid_weights(self, weights):
with ops.Graph().as_default():
self.assertEqual(0, len(util.get_losses()))

@@ -32,8 +32,19 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
def remove_squeezable_dimensions(labels, predictions, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
def remove_squeezable_dimensions(
labels, predictions, expected_rank_diff=0, name=None):
"""Squeeze last dim if ranks differ from expected by exactly 1.
In the common case where we expect shapes to match, `expected_rank_diff`
defaults to 0, and we squeeze the last dimension of the larger rank if they
differ by 1.
But, for example, if `labels` contains class IDs and `predictions` contains 1
probability per class, we expect `predictions` to have 1 more dimension than
`labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
`labels` if `rank(predictions) - rank(labels) == 0`, and
`predictions` if `rank(predictions) - rank(labels) == 2`.
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
@@ -41,6 +52,7 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
Args:
labels: Label values, a `Tensor` whose dimensions match `predictions`.
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
name: Name of the op.
Returns:
@@ -57,10 +69,10 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
if (labels_rank is not None) and (predictions_rank is not None):
# Use static rank.
rank_diff = predictions_rank - labels_rank
if rank_diff == -1:
labels = array_ops.squeeze(labels, [-1])
elif rank_diff == 1:
if rank_diff == expected_rank_diff + 1:
predictions = array_ops.squeeze(predictions, [-1])
elif rank_diff == expected_rank_diff - 1:
labels = array_ops.squeeze(labels, [-1])
return labels, predictions
# Use dynamic rank.
@@ -68,13 +80,13 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
if (predictions_rank is None) or (
predictions_shape.dims[-1].is_compatible_with(1)):
predictions = control_flow_ops.cond(
math_ops.equal(1, rank_diff),
math_ops.equal(expected_rank_diff + 1, rank_diff),
lambda: array_ops.squeeze(predictions, [-1]),
lambda: predictions)
if (labels_rank is None) or (
labels_shape.dims[-1].is_compatible_with(1)):
labels = control_flow_ops.cond(
math_ops.equal(-1, rank_diff),
math_ops.equal(expected_rank_diff - 1, rank_diff),
lambda: array_ops.squeeze(labels, [-1]),
lambda: labels)
return labels, predictions
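As a rough illustration of the expected_rank_diff semantics described in the docstring above (not part of the commit; shapes are made up, and the helper is imported from tensorflow.python.ops.confusion_matrix as in this revision):

import numpy as np
from tensorflow.python.ops import confusion_matrix

# labels carry class IDs and predictions carry one value per class, so we
# expect rank(predictions) - rank(labels) == 1 and pass expected_rank_diff=1.
labels = np.ones((2, 3, 1), dtype=np.int64)          # squeezable trailing dim
predictions = np.zeros((2, 3, 5), dtype=np.float32)

# The actual rank difference is 0 (== expected_rank_diff - 1), so labels are
# squeezed from (2, 3, 1) to (2, 3); predictions come back unchanged.
labels, predictions = confusion_matrix.remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=1)
print(labels.get_shape())  # (2, 3)

With the default expected_rank_diff=0, the same inputs would squeeze nothing, since the actual rank difference is already 0.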

@@ -22,6 +22,8 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:confusion_matrix",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",

@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
@@ -130,16 +132,11 @@ def compute_weighted_loss(
losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES):
"""Computes the weighted loss.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
weights: `Tensor` of shape `[]`, `[batch_size]` or
`[batch_size, d1, ... dK]`, where K < N.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`losses`, and must be broadcastable to `losses` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: the scope for the operations performed in computing the loss.
loss_collection: the loss will be added to these collections.
@@ -180,17 +177,12 @@ def absolute_difference(
measurable element of `predictions` is scaled by the corresponding value of
`weights`.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: The ground truth output tensor, same dimensions as 'predictions'.
predictions: The predicted outputs.
weights: Coefficients for the loss a scalar, a tensor of shape
`[batch_size]` or a tensor whose shape matches `predictions`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
@@ -218,18 +210,13 @@ def cosine_distance(
Note that the function assumes that `predictions` and `labels` are already
unit-normalized.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: `Tensor` whose shape matches 'predictions'
predictions: An arbitrary matrix.
dim: The dimension along which the cosine distance is computed.
weights: Coefficients for the loss a scalar, a tensor of shape
`[batch_size]` or a tensor whose shape matches `predictions`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
@@ -257,18 +244,13 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES):
"""Adds a hinge loss to the training procedure.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: The ground truth output tensor. Its shape should match the shape of
logits. The values of the tensor are expected to be 0.0 or 1.0.
logits: The logits, a float tensor.
weights: Coefficients for the loss a scalar, a tensor of shape
`[batch_size]` or a tensor whose shape matches `predictions`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -302,17 +284,12 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
measurable element of `predictions` is scaled by the corresponding value of
`weights`.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: The ground truth output tensor, same dimensions as 'predictions'.
predictions: The predicted outputs.
weights: Coefficients for the loss a scalar, a tensor of shape
`[batch_size]` or a tensor whose shape matches `predictions`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
epsilon: A small increment to add to avoid taking a log of zero.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -427,17 +404,12 @@ def mean_squared_error(labels, predictions, weights=1.0, scope=None,
measurable element of `predictions` is scaled by the corresponding value of
`weights`.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: The ground truth output tensor, same dimensions as 'predictions'.
predictions: The predicted outputs.
weights: Coefficients for the loss a scalar, a tensor of shape
`[batch_size]` or a tensor whose shape matches `predictions`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -467,12 +439,6 @@ def sigmoid_cross_entropy(
tensor of shape `[batch_size]`, then the loss weights apply to each
corresponding sample.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
If `label_smoothing` is nonzero, smooth the labels towards 1/2:
new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
@@ -482,8 +448,9 @@ def sigmoid_cross_entropy(
multi_class_labels: `[batch_size, num_classes]` target integer labels in
`(0, 1)`.
logits: `[batch_size, num_classes]` logits outputs of the network.
weights: Coefficients for the loss. This must be of shape `[]`,
`[batch_size]` or `[batch_size, num_classes]`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
label_smoothing: If greater than `0` then smooth the labels.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -522,12 +489,6 @@ def softmax_cross_entropy(
tensor of shape `[batch_size]`, then the loss weights apply to each
corresponding sample.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ label_smoothing / num_classes
@@ -535,8 +496,10 @@ def softmax_cross_entropy(
Args:
onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels.
logits: [batch_size, num_classes] logits outputs of the network .
weights: Coefficients for the loss. This must be of shape `[]`,
`[batch_size]` or `[batch_size, num_classes]`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`onehot_labels`, and must be broadcastable to `onehot_labels` (i.e., all
dimensions must be either `1`, or the same as the corresponding `losses`
dimension).
label_smoothing: If greater than 0 then smooth the labels.
scope: the scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -567,6 +530,57 @@ def softmax_cross_entropy(
return compute_weighted_loss(losses, weights, scope, loss_collection)
# TODO(ptucker): Merge this with similar method in metrics_impl.
def _remove_squeezable_dimensions(
labels, predictions, weights=None, expected_rank_diff=0):
"""Internal version of _remove_squeezable_dimensions which handles weights.
Squeezes `predictions` and `labels` if their ranks differ from expected by
exactly 1.
Squeezes `weights` if its rank is 1 more than the new rank of `predictions`
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
Args:
labels: Label values, a `Tensor` whose dimensions match `predictions`.
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
and its rank is 1 more than the new rank of `labels`.
expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
Returns:
Tuple of `predictions`, `labels` and `weights`, possibly with the last
dimension squeezed.
"""
labels, predictions = confusion_matrix.remove_squeezable_dimensions(
labels, predictions, expected_rank_diff=expected_rank_diff)
if weights is not None:
weights = ops.convert_to_tensor(weights)
labels_rank = labels.get_shape().ndims
weights_shape = weights.get_shape()
weights_rank = weights_shape.ndims
if (labels_rank is not None) and (weights_rank is not None):
# Use static rank.
rank_diff = weights_rank - labels_rank
if rank_diff == 1:
weights = array_ops.squeeze(weights, [-1])
return labels, predictions, weights
# Use dynamic rank.
rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
if (weights_rank is None) or (
weights_shape.dims[-1].is_compatible_with(1)):
weights = control_flow_ops.cond(
math_ops.equal(1, rank_diff),
lambda: array_ops.squeeze(weights, [-1]),
lambda: weights)
return labels, predictions, weights
def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES):
"""Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
@@ -576,18 +590,16 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
tensor of shape [`batch_size`], then the loss weights apply to each
corresponding sample.
WARNING: `weights` also supports dimensions of 1, but the broadcasting does
not work as advertised, you'll wind up with weighted sum instead of weighted
mean for any but the last dimension. This will be cleaned up soon, so please
do not rely on the current behavior for anything but the shapes documented for
`weights` below.
Args:
labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
`int64` in the range `[0, num_classes)`.
logits: [batch_size, num_classes] logits outputs of the network .
weights: Coefficients for the loss. This must be scalar or of shape
`[batch_size, 1]`.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
`labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
must be an index in `[0, num_classes)`. Other values will raise an
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
`[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
weights: Coefficients for the loss. This must be scalar or of same rank as
`labels`
scope: the scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
@@ -600,12 +612,12 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
"""
with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
(logits, labels, weights)) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
# As documented above in Args, labels contain class IDs and logits contains
# 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
# therefore, expected_rank_diff=1.
labels, logits, weights = _remove_squeezable_dimensions(
labels, logits, weights, expected_rank_diff=1)
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits,
name="xentropy")
# Reshape losses to [batch_size, 1] to be consistent with weights.
# TODO(ptucker): reshape to (-1, 1) more efficient?
losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])
return compute_weighted_loss(losses, weights, scope, loss_collection)
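As a rough usage sketch of the weight handling this new _remove_squeezable_dimensions path enables (TF 1.x-era tf.losses API assumed; values are illustrative): scalar weights, per-sample weights of shape [batch_size], and [batch_size, 1] weights whose trailing 1 gets squeezed are all accepted.

import tensorflow as tf

labels = tf.constant([0, 2], dtype=tf.int64)        # [batch_size]
logits = tf.constant([[10.0, 0.0, 0.0],
                      [0.0, 0.0, 10.0]])            # [batch_size, num_classes]

losses = [
    tf.losses.sparse_softmax_cross_entropy(labels, logits, weights=w)
    for w in (2.0,                                  # scalar
              tf.constant([1.0, 0.5]),              # [batch_size]
              tf.constant([[1.0], [0.5]]))          # [batch_size, 1] -> squeezed
]

with tf.Session() as sess:
  print(sess.run(losses))

Note that the new testInvalidScalar1DWeight case above treats a (17.0,) weight as invalid for compute_weighted_loss, so a 1-D single-element tensor is not a substitute for a true scalar weight.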

@@ -87,7 +87,7 @@ def _remove_squeezable_dimensions(labels, predictions, weights):
weights = array_ops.squeeze(weights, [-1])
elif (weights_rank is None) or (
weights_shape.dims[-1].is_compatible_with(1)):
# Use dynamic rank
# Use dynamic rank.
weights = control_flow_ops.cond(
math_ops.equal(array_ops.rank(weights),
math_ops.add(array_ops.rank(predictions), 1)),

@@ -1663,13 +1663,13 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
`int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
Other values will raise an exception when this op is run on CPU, and
return `NaN` for corresponding corresponding loss and gradient rows
on GPU.
logits: Unscaled log probabilities of rank `r` and shape
`[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
`labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
must be an index in `[0, num_classes)`. Other values will raise an
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
`[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
name: A name for the operation (optional).
Returns: