Change weight broadcasting for losses.sparse_softmax_cross_entropy to match sparse_softmax_cross_entropy_with_logits. This allows inputs with rank greater than 2.
Also, fix some pydoc.
Change: 144898465
parent d8fd68242e
commit 91a6f2f4d6
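For illustration, a minimal sketch of the behavior this change enables, assuming the TF 1.x `tf.losses` endpoint; shapes and values are made up:

import numpy as np
import tensorflow as tf

# Rank-3 logits (batch of 2 sequences, 3 steps, 5 classes) and rank-2 labels.
logits = tf.constant(np.random.randn(2, 3, 5).astype(np.float32))
labels = tf.constant(np.random.randint(0, 5, size=(2, 3)).astype(np.int64))

# Weights may be a scalar or broadcastable to the per-entry losses of shape
# [2, 3]; here every entry is weighted equally.
loss = tf.losses.sparse_softmax_cross_entropy(
    labels=labels, logits=logits, weights=tf.ones([2, 3]))

with tf.Session() as sess:
  print(sess.run(loss))  # scalar weighted loss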
@@ -22,6 +22,7 @@ import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
@@ -215,5 +216,239 @@ class ConfusionMatrixTest(test.TestCase):
    self.assertEqual(tf_cm.dtype, np.int64)


class RemoveSqueezableDimensionsTest(test.TestCase):

  def testBothScalarShape(self):
    label_values = 1.0
    prediction_values = 0.0
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.float32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.float32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    with self.test_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testSameShape(self):
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros_like(label_values)
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    with self.test_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testSameShapeExpectedRankDiff0(self):
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros_like(label_values)
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values, expected_rank_diff=0))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder, expected_rank_diff=0))

    with self.test_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testSqueezableLabels(self):
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros(shape=(2, 3))
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    expected_label_values = np.reshape(label_values, newshape=(2, 3))
    with self.test_session():
      self.assertAllEqual(expected_label_values, static_labels.eval())
      self.assertAllEqual(prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testSqueezableLabelsExpectedRankDiffPlus1(self):
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros(shape=(2, 3, 5))
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values, expected_rank_diff=1))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder, expected_rank_diff=1))

    expected_label_values = np.reshape(label_values, newshape=(2, 3))
    with self.test_session():
      self.assertAllEqual(expected_label_values, static_labels.eval())
      self.assertAllEqual(prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          expected_label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testSqueezablePredictions(self):
    label_values = np.ones(shape=(2, 3))
    prediction_values = np.zeros(shape=(2, 3, 1))
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
    with self.test_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(expected_prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          expected_prediction_values,
          dynamic_predictions.eval(feed_dict=feed_dict))

  def testSqueezablePredictionsExpectedRankDiffMinus1(self):
    label_values = np.ones(shape=(2, 3, 5))
    prediction_values = np.zeros(shape=(2, 3, 1))
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values, expected_rank_diff=-1))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder, expected_rank_diff=-1))

    expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
    with self.test_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(expected_prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          expected_prediction_values,
          dynamic_predictions.eval(feed_dict=feed_dict))

  def testUnsqueezableLabels(self):
    label_values = np.ones(shape=(2, 3, 2))
    prediction_values = np.zeros(shape=(2, 3))
    with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
      confusion_matrix.remove_squeezable_dimensions(
          label_values, prediction_values)

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    with self.test_session():
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "Tried to explicitly squeeze dimension 2"):
        dynamic_labels.eval(feed_dict=feed_dict)
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))

  def testUnsqueezablePredictions(self):
    label_values = np.ones(shape=(2, 3))
    prediction_values = np.zeros(shape=(2, 3, 2))
    with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
      confusion_matrix.remove_squeezable_dimensions(
          label_values, prediction_values)

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    with self.test_session():
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "Tried to explicitly squeeze dimension 2"):
        dynamic_predictions.eval(feed_dict=feed_dict)


if __name__ == "__main__":
  test.main()
@@ -303,7 +303,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
    weights = 2.3
    with self.test_session():
      loss = losses.sparse_softmax_cross_entropy(
          labels, logits, constant_op.constant(weights, shape=(1, 1)))
          labels, logits, constant_op.constant((weights,)))
      self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)

  def testNonZeroLossWithPlaceholderForWeights(self):
@@ -432,7 +432,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
    labels = constant_op.constant([[0, 1], [2, 3]])
    weights = constant_op.constant(1.2)

    with self.assertRaises(errors_impl.InvalidArgumentError):
    with self.assertRaisesRegexp(ValueError, 'dimension'):
      losses.sparse_softmax_cross_entropy(
          labels, logits, weights=weights).eval()

@@ -1241,6 +1241,9 @@ class ComputeWeightedLossTest(test.TestCase):
  def testInvalid1x2Weight(self):
    self._test_invalid_weights((17.0, 3.0,),)

  def testInvalidScalar1DWeight(self):
    self._test_invalid_weights((17.0,),)

  def _test_valid_weights(self, weights):
    with ops.Graph().as_default():
      self.assertEqual(0, len(util.get_losses()))
@@ -32,8 +32,19 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


def remove_squeezable_dimensions(labels, predictions, name=None):
  """Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.
@@ -41,6 +52,7 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
@@ -57,10 +69,10 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
  if (labels_rank is not None) and (predictions_rank is not None):
    # Use static rank.
    rank_diff = predictions_rank - labels_rank
    if rank_diff == -1:
      labels = array_ops.squeeze(labels, [-1])
    elif rank_diff == 1:
    if rank_diff == expected_rank_diff + 1:
      predictions = array_ops.squeeze(predictions, [-1])
    elif rank_diff == expected_rank_diff - 1:
      labels = array_ops.squeeze(labels, [-1])
    return labels, predictions

  # Use dynamic rank.
@@ -68,13 +80,13 @@ def remove_squeezable_dimensions(labels, predictions, name=None):
  if (predictions_rank is None) or (
      predictions_shape.dims[-1].is_compatible_with(1)):
    predictions = control_flow_ops.cond(
        math_ops.equal(1, rank_diff),
        math_ops.equal(expected_rank_diff + 1, rank_diff),
        lambda: array_ops.squeeze(predictions, [-1]),
        lambda: predictions)
  if (labels_rank is None) or (
      labels_shape.dims[-1].is_compatible_with(1)):
    labels = control_flow_ops.cond(
        math_ops.equal(-1, rank_diff),
        math_ops.equal(expected_rank_diff - 1, rank_diff),
        lambda: array_ops.squeeze(labels, [-1]),
        lambda: labels)
  return labels, predictions
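For illustration, a minimal sketch of the new `expected_rank_diff` argument, mirroring `testSqueezableLabelsExpectedRankDiffPlus1` above; shapes are made up:

import numpy as np
from tensorflow.python.ops import confusion_matrix

# Class-ID labels carry a trailing singleton dim; predictions carry one score
# per class, so rank(predictions) - rank(labels) is expected to be 1.
labels = np.ones(shape=(2, 3, 1), dtype=np.int64)
predictions = np.zeros(shape=(2, 3, 5), dtype=np.float32)

squeezed_labels, squeezed_predictions = (
    confusion_matrix.remove_squeezable_dimensions(
        labels, predictions, expected_rank_diff=1))
# squeezed_labels has shape [2, 3]; squeezed_predictions keeps [2, 3, 5].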
@@ -22,6 +22,8 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:confusion_matrix",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
@@ -20,6 +20,8 @@ from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
@@ -130,16 +132,11 @@ def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Computes the weighted loss.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    weights: `Tensor` of shape `[]`, `[batch_size]` or
      `[batch_size, d1, ... dK]`, where K < N.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: the scope for the operations performed in computing the loss.
    loss_collection: the loss will be added to these collections.

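For illustration, a minimal sketch of the weights contract described above for `compute_weighted_loss`, assuming the TF 1.x `tf.losses` endpoint; values are made up:

import tensorflow as tf

# Per-element losses of shape [2, 3].
losses = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

# Same rank as `losses`, every dimension either 1 or the matching dimension:
# shape [2, 1] scales the first row by 1.0 and the second by 0.5.
weights = tf.constant([[1.0], [0.5]])

loss = tf.losses.compute_weighted_loss(losses, weights=weights)

with tf.Session() as sess:
  print(sess.run(loss))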
@@ -180,17 +177,12 @@ def absolute_difference(
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    predictions: The predicted outputs.
    weights: Coefficients for the loss a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.

@@ -218,18 +210,13 @@ def cosine_distance(
  Note that the function assumes that `predictions` and `labels` are already
  unit-normalized.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: `Tensor` whose shape matches 'predictions'
    predictions: An arbitrary matrix.
    dim: The dimension along which the cosine distance is computed.
    weights: Coefficients for the loss a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.

@@ -257,18 +244,13 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
               loss_collection=ops.GraphKeys.LOSSES):
  """Adds a hinge loss to the training procedure.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: The ground truth output tensor. Its shape should match the shape of
      logits. The values of the tensor are expected to be 0.0 or 1.0.
    logits: The logits, a float tensor.
    weights: Coefficients for the loss a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.

@@ -302,17 +284,12 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    predictions: The predicted outputs.
    weights: Coefficients for the loss a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    epsilon: A small increment to add to avoid taking a log of zero.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
@@ -427,17 +404,12 @@ def mean_squared_error(labels, predictions, weights=1.0, scope=None,
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    predictions: The predicted outputs.
    weights: Coefficients for the loss a scalar, a tensor of shape
      `[batch_size]` or a tensor whose shape matches `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.

@@ -467,12 +439,6 @@ def sigmoid_cross_entropy(
  tensor of shape `[batch_size]`, then the loss weights apply to each
  corresponding sample.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  If `label_smoothing` is nonzero, smooth the labels towards 1/2:

      new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
@@ -482,8 +448,9 @@ def sigmoid_cross_entropy(
    multi_class_labels: `[batch_size, num_classes]` target integer labels in
      `(0, 1)`.
    logits: `[batch_size, num_classes]` logits outputs of the network.
    weights: Coefficients for the loss. This must be of shape `[]`,
      `[batch_size]` or `[batch_size, num_classes]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    label_smoothing: If greater than `0` then smooth the labels.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
@@ -522,12 +489,6 @@ def softmax_cross_entropy(
  tensor of shape `[batch_size]`, then the loss weights apply to each
  corresponding sample.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
      new_onehot_labels = onehot_labels * (1 - label_smoothing)
                          + label_smoothing / num_classes
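For illustration, working through the label-smoothing formula above with made-up numbers (num_classes = 4, label_smoothing = 0.2):

label_smoothing = 0.2
num_classes = 4

# Each 1 in the one-hot vector becomes 1 * 0.8 + 0.2 / 4 = 0.85,
# and each 0 becomes 0 * 0.8 + 0.05 = 0.05.
onehot_labels = [1.0, 0.0, 0.0, 0.0]
new_onehot_labels = [v * (1 - label_smoothing) + label_smoothing / num_classes
                     for v in onehot_labels]
print(new_onehot_labels)  # approximately [0.85, 0.05, 0.05, 0.05]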
@@ -535,8 +496,10 @@ def softmax_cross_entropy(
  Args:
    onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels.
    logits: [batch_size, num_classes] logits outputs of the network .
    weights: Coefficients for the loss. This must be of shape `[]`,
      `[batch_size]` or `[batch_size, num_classes]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `onehot_labels`, and must be broadcastable to `onehot_labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `losses`
      dimension).
    label_smoothing: If greater than 0 then smooth the labels.
    scope: the scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
@@ -567,6 +530,57 @@ def softmax_cross_entropy(
    return compute_weighted_loss(losses, weights, scope, loss_collection)


# TODO(ptucker): Merge this with similar method in metrics_impl.
def _remove_squeezable_dimensions(
    labels, predictions, weights=None, expected_rank_diff=0):
  """Internal version of _remove_squeezable_dimensions which handles weights.

  Squeezes `predictions` and `labels` if their ranks differ from expected by
  exactly 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
      and its rank is 1 more than the new rank of `labels`.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.
  """
  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
      labels, predictions, expected_rank_diff=expected_rank_diff)

  if weights is not None:
    weights = ops.convert_to_tensor(weights)
    labels_rank = labels.get_shape().ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (labels_rank is not None) and (weights_rank is not None):
      # Use static rank.
      rank_diff = weights_rank - labels_rank
      if rank_diff == 1:
        weights = array_ops.squeeze(weights, [-1])
      return labels, predictions, weights

    # Use dynamic rank.
    rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
    if (weights_rank is None) or (
        weights_shape.dims[-1].is_compatible_with(1)):
      weights = control_flow_ops.cond(
          math_ops.equal(1, rank_diff),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)

  return labels, predictions, weights


def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
                                 loss_collection=ops.GraphKeys.LOSSES):
  """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
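For illustration, a rough sketch of how the new private helper treats weights, based only on the code above; the helper is internal, the module path tensorflow.python.ops.losses.losses_impl is assumed, and the shapes are made up:

import numpy as np
from tensorflow.python.ops.losses import losses_impl

labels = np.ones(shape=(2, 3), dtype=np.int64)         # class IDs
logits = np.zeros(shape=(2, 3, 5), dtype=np.float32)   # one score per class
weights = np.ones(shape=(2, 3, 1), dtype=np.float32)   # rank(labels) + 1

labels, logits, weights = losses_impl._remove_squeezable_dimensions(
    labels, logits, weights, expected_rank_diff=1)
print(weights.get_shape())  # (2, 3): trailing singleton dimension squeezed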
@@ -576,18 +590,16 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
  tensor of shape [`batch_size`], then the loss weights apply to each
  corresponding sample.

  WARNING: `weights` also supports dimensions of 1, but the broadcasting does
  not work as advertised, you'll wind up with weighted sum instead of weighted
  mean for any but the last dimension. This will be cleaned up soon, so please
  do not rely on the current behavior for anything but the shapes documented for
  `weights` below.

  Args:
    labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
      `int64` in the range `[0, num_classes)`.
    logits: [batch_size, num_classes] logits outputs of the network .
    weights: Coefficients for the loss. This must be scalar or of shape
      `[batch_size, 1]`.
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
      must be an index in `[0, num_classes)`. Other values will raise an
      exception when this op is run on CPU, and return `NaN` for corresponding
      loss and gradient rows on GPU.
    logits: Unscaled log probabilities of shape
      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
    weights: Coefficients for the loss. This must be scalar or of same rank as
      `labels`
    scope: the scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.

@@ -600,12 +612,12 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
  """
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      (logits, labels, weights)) as scope:
    labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])

    # As documented above in Args, labels contain class IDs and logits contains
    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
    # therefore, expected_rank_diff=1.
    labels, logits, weights = _remove_squeezable_dimensions(
        labels, logits, weights, expected_rank_diff=1)
    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                         logits=logits,
                                                         name="xentropy")
    # Reshape losses to [batch_size, 1] to be consistent with weights.
    # TODO(ptucker): reshape to (-1, 1) more efficient?
    losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])
    return compute_weighted_loss(losses, weights, scope, loss_collection)
@@ -87,7 +87,7 @@ def _remove_squeezable_dimensions(labels, predictions, weights):
      weights = array_ops.squeeze(weights, [-1])
    elif (weights_rank is None) or (
        weights_shape.dims[-1].is_compatible_with(1)):
      # Use dynamic rank
      # Use dynamic rank.
      weights = control_flow_ops.cond(
          math_ops.equal(array_ops.rank(weights),
                         math_ops.add(array_ops.rank(predictions), 1)),
@@ -1663,13 +1663,13 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
      `int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
      Other values will raise an exception when this op is run on CPU, and
      return `NaN` for corresponding corresponding loss and gradient rows
      on GPU.
    logits: Unscaled log probabilities of rank `r` and shape
      `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
      must be an index in `[0, num_classes)`. Other values will raise an
      exception when this op is run on CPU, and return `NaN` for corresponding
      loss and gradient rows on GPU.
    logits: Unscaled log probabilities of shape
      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
    name: A name for the operation (optional).

  Returns:
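For illustration, the corrected shape contract above in a minimal sketch with made-up shapes (labels `[d_0, ..., d_{r-1}]`, logits `[d_0, ..., d_{r-1}, num_classes]`, result shaped like `labels`):

import numpy as np
import tensorflow as tf

labels = tf.constant(np.random.randint(0, 5, size=(2, 3)).astype(np.int64))
logits = tf.constant(np.random.randn(2, 3, 5).astype(np.float32))

# Per-entry losses come back with the same shape as `labels`, here [2, 3].
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)

with tf.Session() as sess:
  print(sess.run(losses).shape)  # (2, 3)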