diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index cf882091488..2d116df2ffb 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import confusion_matrix
 from tensorflow.python.ops import math_ops
@@ -215,5 +216,239 @@ class ConfusionMatrixTest(test.TestCase):
     self.assertEqual(tf_cm.dtype, np.int64)
 
+
+class RemoveSqueezableDimensionsTest(test.TestCase):
+
+  def testBothScalarShape(self):
+    label_values = 1.0
+    prediction_values = 0.0
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.float32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.float32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    with self.test_session():
+      self.assertAllEqual(label_values, static_labels.eval())
+      self.assertAllEqual(prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSameShape(self):
+    label_values = np.ones(shape=(2, 3, 1))
+    prediction_values = np.zeros_like(label_values)
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    with self.test_session():
+      self.assertAllEqual(label_values, static_labels.eval())
+      self.assertAllEqual(prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSameShapeExpectedRankDiff0(self):
+    label_values = np.ones(shape=(2, 3, 1))
+    prediction_values = np.zeros_like(label_values)
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values, expected_rank_diff=0))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder, expected_rank_diff=0))
+
+    with self.test_session():
+      self.assertAllEqual(label_values, static_labels.eval())
+      self.assertAllEqual(prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSqueezableLabels(self):
+    label_values = np.ones(shape=(2, 3, 1))
+    prediction_values = np.zeros(shape=(2, 3))
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    expected_label_values = np.reshape(label_values, newshape=(2, 3))
+    with self.test_session():
+      self.assertAllEqual(expected_label_values, static_labels.eval())
+      self.assertAllEqual(prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSqueezableLabelsExpectedRankDiffPlus1(self):
+    label_values = np.ones(shape=(2, 3, 1))
+    prediction_values = np.zeros(shape=(2, 3, 5))
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values, expected_rank_diff=1))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder, expected_rank_diff=1))
+
+    expected_label_values = np.reshape(label_values, newshape=(2, 3))
+    with self.test_session():
+      self.assertAllEqual(expected_label_values, static_labels.eval())
+      self.assertAllEqual(prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSqueezablePredictions(self):
+    label_values = np.ones(shape=(2, 3))
+    prediction_values = np.zeros(shape=(2, 3, 1))
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
+    with self.test_session():
+      self.assertAllEqual(label_values, static_labels.eval())
+      self.assertAllEqual(expected_prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          expected_prediction_values,
+          dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testSqueezablePredictionsExpectedRankDiffMinus1(self):
+    label_values = np.ones(shape=(2, 3, 5))
+    prediction_values = np.zeros(shape=(2, 3, 1))
+    static_labels, static_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            label_values, prediction_values, expected_rank_diff=-1))
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder, expected_rank_diff=-1))
+
+    expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
+    with self.test_session():
+      self.assertAllEqual(label_values, static_labels.eval())
+      self.assertAllEqual(expected_prediction_values, static_predictions.eval())
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      self.assertAllEqual(
+          expected_prediction_values,
+          dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testUnsqueezableLabels(self):
+    label_values = np.ones(shape=(2, 3, 2))
+    prediction_values = np.zeros(shape=(2, 3))
+    with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
+      confusion_matrix.remove_squeezable_dimensions(
+          label_values, prediction_values)
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    with self.test_session():
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          "Tried to explicitly squeeze dimension 2"):
+        dynamic_labels.eval(feed_dict=feed_dict)
+      self.assertAllEqual(
+          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
+
+  def testUnsqueezablePredictions(self):
+    label_values = np.ones(shape=(2, 3))
+    prediction_values = np.zeros(shape=(2, 3, 2))
+    with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
+      confusion_matrix.remove_squeezable_dimensions(
+          label_values, prediction_values)
+
+    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
+    dynamic_labels, dynamic_predictions = (
+        confusion_matrix.remove_squeezable_dimensions(
+            labels_placeholder, predictions_placeholder))
+
+    with self.test_session():
+      feed_dict = {
+          labels_placeholder: label_values,
+          predictions_placeholder: prediction_values
+      }
+      self.assertAllEqual(
+          label_values, dynamic_labels.eval(feed_dict=feed_dict))
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          "Tried to explicitly squeeze dimension 2"):
+        dynamic_predictions.eval(feed_dict=feed_dict)
+
+
 
 if __name__ == "__main__":
   test.main()
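Note on the test structure above: each case builds the op twice, once from concrete values (the static-shape path) and once from placeholders of unknown rank (the dynamic-rank path, implemented with a run-time cond). The following standalone sketch of the dynamic path is not part of the patch; it assumes a TF 1.x graph-mode session and the same internal import the test file uses.

    import numpy as np
    import tensorflow as tf

    from tensorflow.python.ops import confusion_matrix

    # Placeholders with no shape have unknown rank, so the squeeze decision is
    # deferred to run time, as in the "dynamic_*" half of each test above.
    labels = tf.placeholder(dtype=tf.int32)
    predictions = tf.placeholder(dtype=tf.int32)
    squeezed_labels, _ = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)

    with tf.Session() as sess:
      out = sess.run(squeezed_labels, feed_dict={
          labels: np.ones((2, 3, 1), dtype=np.int32),
          predictions: np.zeros((2, 3), dtype=np.int32)})
      print(out.shape)  # (2, 3): the trailing size-1 dim is squeezed at run time.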
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index ec356373657..f3ae092b6f5 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -303,7 +303,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
     weights = 2.3
     with self.test_session():
       loss = losses.sparse_softmax_cross_entropy(
-          labels, logits, constant_op.constant(weights, shape=(1, 1)))
+          labels, logits, constant_op.constant((weights,)))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
 
   def testNonZeroLossWithPlaceholderForWeights(self):
@@ -432,7 +432,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
     labels = constant_op.constant([[0, 1], [2, 3]])
     weights = constant_op.constant(1.2)
 
-    with self.assertRaises(errors_impl.InvalidArgumentError):
+    with self.assertRaisesRegexp(ValueError, 'dimension'):
       losses.sparse_softmax_cross_entropy(
           labels, logits, weights=weights).eval()
 
@@ -1241,6 +1241,9 @@ class ComputeWeightedLossTest(test.TestCase):
   def testInvalid1x2Weight(self):
     self._test_invalid_weights((17.0, 3.0,),)
 
+  def testInvalidScalar1DWeight(self):
+    self._test_invalid_weights((17.0,),)
+
   def _test_valid_weights(self, weights):
     with ops.Graph().as_default():
       self.assertEqual(0, len(util.get_losses()))
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index 628853545e9..95247ea125f 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -32,8 +32,19 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
 
 
-def remove_squeezable_dimensions(labels, predictions, name=None):
-  """Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
+def remove_squeezable_dimensions(
+    labels, predictions, expected_rank_diff=0, name=None):
+  """Squeeze last dim if ranks differ from expected by exactly 1.
+
+  In the common case where we expect shapes to match, `expected_rank_diff`
+  defaults to 0, and we squeeze the last dimension of the larger rank if they
+  differ by 1.
+
+  But, for example, if `labels` contains class IDs and `predictions` contains 1
+  probability per class, we expect `predictions` to have 1 more dimension than
+  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
+  `labels` if `rank(predictions) - rank(labels) == 0`, and
+  `predictions` if `rank(predictions) - rank(labels) == 2`.
 
   This will use static shape if available. Otherwise, it will add graph
   operations, which could result in a performance hit.
@@ -41,6 +52,7 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
   Args:
     labels: Label values, a `Tensor` whose dimensions match `predictions`.
     predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
     name: Name of the op.
 
   Returns:
@@ -57,10 +69,10 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
     if (labels_rank is not None) and (predictions_rank is not None):
       # Use static rank.
       rank_diff = predictions_rank - labels_rank
-      if rank_diff == -1:
-        labels = array_ops.squeeze(labels, [-1])
-      elif rank_diff == 1:
+      if rank_diff == expected_rank_diff + 1:
         predictions = array_ops.squeeze(predictions, [-1])
+      elif rank_diff == expected_rank_diff - 1:
+        labels = array_ops.squeeze(labels, [-1])
       return labels, predictions
 
     # Use dynamic rank.
@@ -68,13 +80,13 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
     if (predictions_rank is None) or (
         predictions_shape.dims[-1].is_compatible_with(1)):
       predictions = control_flow_ops.cond(
-          math_ops.equal(1, rank_diff),
+          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
     if (labels_rank is None) or (
         labels_shape.dims[-1].is_compatible_with(1)):
       labels = control_flow_ops.cond(
-          math_ops.equal(-1, rank_diff),
+          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
     return labels, predictions
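To make the new `expected_rank_diff` argument concrete: when labels carry class IDs and predictions carry one value per class, predictions are expected to have one more dimension than labels, so a stray trailing size-1 dimension on labels is what gets squeezed. A minimal sketch follows; it is not part of the patch and assumes a TF 1.x graph-mode session plus the internal import path used by the tests.

    import numpy as np
    import tensorflow as tf

    from tensorflow.python.ops import confusion_matrix

    # labels carry class IDs, predictions carry one value per class, so
    # rank(predictions) - rank(labels) is expected to be 1.
    labels = tf.constant(np.ones((2, 3, 1), dtype=np.int64))          # rank 3
    predictions = tf.constant(np.zeros((2, 3, 5), dtype=np.float32))  # rank 3
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions, expected_rank_diff=1)

    with tf.Session() as sess:
      squeezed_labels, same_predictions = sess.run([labels, predictions])
      print(squeezed_labels.shape)   # (2, 3): actual rank diff was 0, one less
                                     # than expected, so labels lost the 1-dim.
      print(same_predictions.shape)  # (2, 3, 5): predictions are untouched.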
diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD
index 6bb46d5b615..c4ce11ce0fc 100644
--- a/tensorflow/python/ops/losses/BUILD
+++ b/tensorflow/python/ops/losses/BUILD
@@ -22,6 +22,8 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:confusion_matrix",
+        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn",
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index bcfefaba7c9..89daa9594a2 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import confusion_matrix
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
@@ -130,16 +132,11 @@ def compute_weighted_loss(
     losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES):
   """Computes the weighted loss.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
-    weights: `Tensor` of shape `[]`, `[batch_size]` or
-      `[batch_size, d1, ... dK]`, where K < N.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     scope: the scope for the operations performed in computing the loss.
     loss_collection: the loss will be added to these collections.
 
@@ -180,17 +177,12 @@ def absolute_difference(
   measurable element of `predictions` is scaled by the corresponding value of
   `weights`.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     labels: The ground truth output tensor, same dimensions as 'predictions'.
     predictions: The predicted outputs.
-    weights: Coefficients for the loss a scalar, a tensor of shape
-      `[batch_size]` or a tensor whose shape matches `predictions`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which this loss will be added.
 
@@ -218,18 +210,13 @@ def cosine_distance(
   Note that the function assumes that `predictions` and `labels` are already
   unit-normalized.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     labels: `Tensor` whose shape matches 'predictions'
     predictions: An arbitrary matrix.
     dim: The dimension along which the cosine distance is computed.
-    weights: Coefficients for the loss a scalar, a tensor of shape
-      `[batch_size]` or a tensor whose shape matches `predictions`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which this loss will be added.
 
@@ -257,18 +244,13 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
                loss_collection=ops.GraphKeys.LOSSES):
   """Adds a hinge loss to the training procedure.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     labels: The ground truth output tensor. Its shape should match the shape of
      logits. The values of the tensor are expected to be 0.0 or 1.0.
    logits: The logits, a float tensor.
-    weights: Coefficients for the loss a scalar, a tensor of shape
-      `[batch_size]` or a tensor whose shape matches `predictions`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
 
@@ -302,17 +284,12 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
   measurable element of `predictions` is scaled by the corresponding value of
   `weights`.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     labels: The ground truth output tensor, same dimensions as 'predictions'.
     predictions: The predicted outputs.
-    weights: Coefficients for the loss a scalar, a tensor of shape
-      `[batch_size]` or a tensor whose shape matches `predictions`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     epsilon: A small increment to add to avoid taking a log of zero.
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
@@ -427,17 +404,12 @@ def mean_squared_error(labels, predictions, weights=1.0, scope=None,
   measurable element of `predictions` is scaled by the corresponding value of
   `weights`.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
     labels: The ground truth output tensor, same dimensions as 'predictions'.
     predictions: The predicted outputs.
-    weights: Coefficients for the loss a scalar, a tensor of shape
-      `[batch_size]` or a tensor whose shape matches `predictions`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
 
@@ -467,12 +439,6 @@ def sigmoid_cross_entropy(
   tensor of shape `[batch_size]`, then the loss weights apply to each
   corresponding sample.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   If `label_smoothing` is nonzero, smooth the labels towards 1/2:
 
       new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
@@ -482,8 +448,9 @@
     multi_class_labels: `[batch_size, num_classes]` target integer labels in
       `(0, 1)`.
     logits: `[batch_size, num_classes]` logits outputs of the network.
-    weights: Coefficients for the loss. This must be of shape `[]`,
-      `[batch_size]` or `[batch_size, num_classes]`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
     label_smoothing: If greater than `0` then smooth the labels.
     scope: The scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
@@ -522,12 +489,6 @@ def softmax_cross_entropy(
   tensor of shape `[batch_size]`, then the loss weights apply to each
   corresponding sample.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
       new_onehot_labels = onehot_labels * (1 - label_smoothing)
                           + label_smoothing / num_classes
@@ -535,8 +496,10 @@
   Args:
     onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels.
     logits: [batch_size, num_classes] logits outputs of the network .
-    weights: Coefficients for the loss. This must be of shape `[]`,
-      `[batch_size]` or `[batch_size, num_classes]`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `onehot_labels`, and must be broadcastable to `onehot_labels` (i.e., all
+      dimensions must be either `1`, or the same as the corresponding `losses`
+      dimension).
     label_smoothing: If greater than 0 then smooth the labels.
     scope: the scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
@@ -567,6 +530,57 @@ def softmax_cross_entropy(
   return compute_weighted_loss(losses, weights, scope, loss_collection)
 
 
+# TODO(ptucker): Merge this with similar method in metrics_impl.
+def _remove_squeezable_dimensions(
+    labels, predictions, weights=None, expected_rank_diff=0):
+  """Internal version of _remove_squeezable_dimensions which handles weights.
+
+  Squeezes `predictions` and `labels` if their ranks differ from expected by
+  exactly 1.
+  Squeezes `weights` if its rank is 1 more than the new rank of `labels`.
+
+  This will use static shape if available. Otherwise, it will add graph
+  operations, which could result in a performance hit.
+
+  Args:
+    labels: Label values, a `Tensor` whose dimensions match `predictions`.
+    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+    weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
+      and its rank is 1 more than the new rank of `labels`.
+    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
+
+  Returns:
+    Tuple of `labels`, `predictions` and `weights`, possibly with the last
+    dimension squeezed.
+  """
+  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
+      labels, predictions, expected_rank_diff=expected_rank_diff)
+
+  if weights is not None:
+    weights = ops.convert_to_tensor(weights)
+    labels_rank = labels.get_shape().ndims
+    weights_shape = weights.get_shape()
+    weights_rank = weights_shape.ndims
+
+    if (labels_rank is not None) and (weights_rank is not None):
+      # Use static rank.
+      rank_diff = weights_rank - labels_rank
+      if rank_diff == 1:
+        weights = array_ops.squeeze(weights, [-1])
+      return labels, predictions, weights
+
+    # Use dynamic rank.
+    rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
+    if (weights_rank is None) or (
+        weights_shape.dims[-1].is_compatible_with(1)):
+      weights = control_flow_ops.cond(
+          math_ops.equal(1, rank_diff),
+          lambda: array_ops.squeeze(weights, [-1]),
+          lambda: weights)
+
+  return labels, predictions, weights
+
+
 def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
                                  loss_collection=ops.GraphKeys.LOSSES):
   """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
@@ -576,18 +590,16 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
   tensor of shape [`batch_size`], then the loss weights apply to each
   corresponding sample.
 
-  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
-  not work as advertised, you'll wind up with weighted sum instead of weighted
-  mean for any but the last dimension. This will be cleaned up soon, so please
-  do not rely on the current behavior for anything but the shapes documented for
-  `weights` below.
-
   Args:
-    labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
-      `int64` in the range `[0, num_classes)`.
-    logits: [batch_size, num_classes] logits outputs of the network .
-    weights: Coefficients for the loss. This must be scalar or of shape
-      `[batch_size, 1]`.
+    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
+      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
+      must be an index in `[0, num_classes)`. Other values will raise an
+      exception when this op is run on CPU, and return `NaN` for corresponding
+      loss and gradient rows on GPU.
+    logits: Unscaled log probabilities of shape
+      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
+    weights: Coefficients for the loss. This must be scalar or of same rank as
+      `labels`.
     scope: the scope for the operations performed in computing the loss.
     loss_collection: collection to which the loss will be added.
 
@@ -600,12 +612,12 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
   """
   with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                       (logits, labels, weights)) as scope:
-    labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
-
+    # As documented above in Args, labels contain class IDs and logits contains
+    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
+    # therefore, expected_rank_diff=1.
+    labels, logits, weights = _remove_squeezable_dimensions(
+        labels, logits, weights, expected_rank_diff=1)
     losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits,
                                                          name="xentropy")
-    # Reshape losses to [batch_size, 1] to be consistent with weights.
-    # TODO(ptucker): reshape to (-1, 1) more efficient?
-    losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])
     return compute_weighted_loss(losses, weights, scope, loss_collection)
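With the losses_impl.py changes above, sparse_softmax_cross_entropy no longer force-reshapes labels to [batch_size]; labels of rank r are paired with logits of rank r + 1, and weights only need to be scalar or of the same rank as labels. A usage sketch follows; it is not part of the patch, calls losses_impl directly, and assumes a TF 1.x graph-mode session. Per the updated compute_weighted_loss docstring, weights broadcastable to the per-entry losses (for example shape [2, 1] here) should also be accepted.

    import numpy as np
    import tensorflow as tf

    from tensorflow.python.ops.losses import losses_impl

    # labels: [batch, time] class IDs; logits: [batch, time, num_classes].
    labels = tf.constant([[0, 1, 2], [2, 1, 0]], dtype=tf.int64)
    logits = tf.constant(np.random.rand(2, 3, 4).astype(np.float32))
    # Per-entry weights with the same shape as labels.
    weights = tf.constant([[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]])

    loss = losses_impl.sparse_softmax_cross_entropy(
        labels, logits, weights=weights)
    with tf.Session() as sess:
      print(sess.run(loss))  # scalar weighted-mean cross-entropy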
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 6fe90340a45..0a109eb99b4 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -87,7 +87,7 @@ def _remove_squeezable_dimensions(labels, predictions, weights):
       weights = array_ops.squeeze(weights, [-1])
     elif (weights_rank is None) or (
         weights_shape.dims[-1].is_compatible_with(1)):
-      # Use dynamic rank
+      # Use dynamic rank.
       weights = control_flow_ops.cond(
           math_ops.equal(array_ops.rank(weights),
                          math_ops.add(array_ops.rank(predictions), 1)),
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 9ad2bf998b4..344a5921065 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1663,13 +1663,13 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=
 
   Args:
     _sentinel: Used to prevent positional parameters. Internal, do not use.
-    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
-      `int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
-      Other values will raise an exception when this op is run on CPU, and
-      return `NaN` for corresponding corresponding loss and gradient rows
-      on GPU.
-    logits: Unscaled log probabilities of rank `r` and shape
-      `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
+    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
+      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
+      must be an index in `[0, num_classes)`. Other values will raise an
+      exception when this op is run on CPU, and return `NaN` for corresponding
+      loss and gradient rows on GPU.
+    logits: Unscaled log probabilities of shape
+      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
     name: A name for the operation (optional).
 
   Returns:
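The nn_ops.py docstring update reflects that sparse_softmax_cross_entropy_with_logits accepts labels of rank r - 1 alongside logits of rank r and returns a per-entry loss with the same shape as labels. A short sketch (not part of the patch; assumes a TF 1.x graph-mode session):

    import numpy as np
    import tensorflow as tf

    # labels: [d_0, d_1] = [2, 3]; logits: [d_0, d_1, num_classes] = [2, 3, 4].
    labels = tf.constant([[0, 1, 2], [3, 2, 1]], dtype=tf.int64)
    logits = tf.constant(np.random.rand(2, 3, 4).astype(np.float32))

    per_entry_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)

    with tf.Session() as sess:
      print(sess.run(per_entry_loss).shape)  # (2, 3): same shape as labels.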