Adding several loss functions including cosine_distance_loss, log_loss, softmax_cross_entropy_loss, and sum_of_pairwise_squares_loss. Refactoring old losses for consistency.

Change: 120130871
This commit is contained in:
A. Unique TensorFlower 2016-04-18 08:37:55 -08:00 committed by TensorFlower Gardener
parent c2d9cb1d08
commit ec89b0c218
3 changed files with 1170 additions and 402 deletions


@@ -18,5 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.losses.python.losses.loss_ops import *
from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference
from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance
from tensorflow.contrib.losses.python.losses.loss_ops import log
from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import sum_of_pairwise_squares
from tensorflow.contrib.losses.python.losses.loss_ops import sum_of_squares


@@ -14,272 +14,552 @@
# ==============================================================================
"""## Loss operations for use in neural networks.
The loss ops measure error for use in neural networks. These losses
can be used to measure the error of a network in a regression task
or for regularization purposes (e.g., weight decay).
All of the loss functions take a pair of predictions and ground truth labels,
from which the loss is computed. It is assumed that the shape of both these
tensors is of the form [batch_size, d1, ... dN] where `batch_size` is the number
of samples in the batch and `d1` ... `dN` are the remaining dimensions.
These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.
It is common, when training with multiple loss functions, to adjust the relative
strengths of individual losses. This is performed by rescaling the losses via
a `weight` parameter passed to the loss functions. For example, if we were
training with both log_loss and sum_of_squares_loss, and we wished that the
log_loss penalty be twice as severe as the sum_of_squares_loss, we would
implement this as:
# Explicitly set the weight.
tf.contrib.losses.log(predictions, targets, weight=2.0)
# Uses default weight of 1.0
tf.contrib.losses.sum_of_squares(predictions, targets)
While specifying a scalar weight rescales the loss over the entire batch,
we sometimes want to rescale the loss per batch sample. For example, if we have
certain examples that matter more to us to get right, we might want a higher
loss for those samples than for other samples whose mistakes matter less. In
this case, we can provide a weight vector of length `batch_size` which results
in the loss for each sample in the batch being scaled by the corresponding
weight element.
For example, consider the case of a classification problem where we want to
maximize our accuracy but we are especially interested in obtaining high
accuracy for a specific class:
inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)
# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.mul(4.0, tf.cast(tf.equal(labels, 3), tf.float32)) + 1
onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
Finally, in certain cases, we may want to specify a different loss for every
single measurable value. For example, if we are performing per-pixel depth
prediction, or per-pixel denoising, a single batch sample has P values where P
is the number of pixels in the image. For many losses, the number of measurable
values matches the number of elements in the predictions and targets tensors.
For others, such as softmax_cross_entropy and cosine_distance, the
loss function reduces the dimensions of the inputs to produce a tensor of
losses with one entry per measurable value. For example, softmax_cross_entropy
takes as input predictions and labels of dimension [batch_size, num_classes],
but its measurable values form a tensor of shape [batch_size]. Consequently,
when passing a weight tensor to specify a different loss for every measurable
value, the shape of the weight tensor will depend on the loss being used.
For a concrete example, consider the case of per-pixel depth prediction where
certain ground truth depth values are missing (due to sensor noise in the
capture process). In this case, we want to assign zero weight to losses for
these predictions.
# 'depths' that are missing have a value of 0:
images, depths = LoadData(...)
predictions = MyModelPredictions(images)
weight = tf.cast(tf.greater(depths, 0), tf.float32)
tf.contrib.losses.sum_of_squares(predictions, depths, weight)
Note that when using weights for the losses, the final average is computed
by rescaling the losses by the weights and then dividing by the total number of
non-zero weights. For an arbitrary set of weights, this may not necessarily
produce a weighted average. Instead, it simply and transparently rescales the
per-element losses before averaging over the number of observations. For
example, if the losses computed by the loss function are [4, 1, 2, 3] and the
weights are [1, 0.5, 3, 9], then the average loss is:
(4*1 + 1*0.5 + 2*3 + 3*9) / 4
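which evaluates to (4 + 0.5 + 6 + 27) / 4 = 9.375. Note that the divisor is the
number of non-zero weights (here, all four), not the sum of the weights.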
However, with a single loss function and an arbitrary set of weights, one can
still easily create a loss function such that the resulting loss is a
weighted average over the individual prediction errors:
images, labels = LoadData(...)
predictions = MyModelPredictions(images)
weight = MyComplicatedWeightingFunction(labels)
weight = tf.div(weight, tf.size(weight))
tf.contrib.losses.sum_of_squares(predictions, labels, weight)
@@absolute_difference
@@cosine_distance
@@log
@@sigmoid_cross_entropy
@@softmax_cross_entropy
@@sum_of_pairwise_squares
@@sum_of_squares
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
__all__ = ["absolute", "squared", "logistic", "softmax"]
__all__ = [
"absolute_difference",
"cosine_distance",
"log",
"sigmoid_cross_entropy",
"softmax_cross_entropy",
"sum_of_pairwise_squares",
"sum_of_squares",
]
def _reduce_batch(x, reduce_fn, name=None):
"""Given a tensor `x`, calls reduce_fn to reduce it across dimensions.
Given a tensor with number of dimensions > 1, _reduce_batch will reduce the
tensor across all dimensions except for dimension 0. As an example, given a
tensor of shape [batch_size, d1, d2], this function will reduce across
dimensions d1 and d2, returning a tensor of shape [batch_size].
Tensors of dimension 1 are returned as-is, while tensors of dimension 0
raise a ValueError.
def _scale_losses(losses, weight):
"""Computes the scaled loss.
Args:
x: A `Tensor` with dimension > 0.
reduce_fn: A math_ops reduce function that takes arguments of
`x`, `reduction_indices`, and `name`.
name: A name for the operation (optional).
losses: A `Tensor` of size [batch_size, d1, ... dN].
weight: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
The `losses` are reduced (tf.reduce_sum) until its dimension matches
that of `weight` at which point the reduced `losses` are element-wise
multiplied by `weight` and a final reduce_sum is computed on the result.
Conceptually, this operation is equivalent to broadcasting (tiling)
`weight` to be the same size as `losses`, performing an element-wise
multiplication, and summing the result.
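For example, given `losses` = [[1, 2, 3], [4, 5, 6]] (shape [2, 3]) and
`weight` = [1, 0] (shape [2]), the losses are first reduced to [6, 15], then
multiplied element-wise by `weight` to give [6, 0], and finally summed to the
scalar 6.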
Returns:
A `Tensor` with values reduced by reduce_fn across all dimensions > 0.
A scalar tf.float32 `Tensor` whose value represents the sum of the scaled
`losses`.
"""
# First, compute the sum of the losses over all elements:
start_index = max(0, weight.get_shape().ndims)
reduction_indices = range(start_index, losses.get_shape().ndims)
reduced_losses = math_ops.reduce_sum(losses,
reduction_indices=reduction_indices)
reduced_losses = math_ops.mul(reduced_losses, weight)
return math_ops.reduce_sum(reduced_losses)
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
Args:
losses: A tensor whose elements contain individual loss measurements.
num_present: The number of measurable losses in the tensor.
Returns:
A scalar representing the mean of the losses. If `num_present` is zero,
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return math_ops.select(num_present > 0,
math_ops.div(total_loss, num_present),
array_ops.zeros_like(total_loss),
name="value")
def _compute_weighted_loss(losses, weight):
"""Computes the weighted loss.
Args:
losses: A tensor of size [batch_size, d1, ... dN].
weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
Returns:
A scalar `Tensor` that returns the weighted loss.
Raises:
ValueError: If `x` has dimension 0.
ValueError: If the weight shape is not compatible with the losses shape or
if the number of dimensions (rank) of either losses or weight is missing.
"""
x = ops.convert_to_tensor(x, name="x")
with ops.op_scope([x], name, "reduce_batch"):
ndims = x.get_shape().ndims
if ndims == 0:
raise ValueError("Cannot reduce a scalar into batches.")
elif ndims == 1:
return x # Don't include a useless reduction.
elif ndims:
reduction_indices = math_ops.range(1, ndims)
shape = [x.get_shape().dims[0]]
else:
reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x)))
shape = [None] # We don't know much about the shape, but it is rank 1.
result = reduce_fn(x, reduction_indices=reduction_indices)
losses = math_ops.to_float(losses)
weight = math_ops.to_float(ops.convert_to_tensor(weight))
# Give a shape hint in case we have extra information.
result.set_shape(shape)
return result
if losses.get_shape().ndims is None:
raise ValueError("losses.get_shape().ndims cannot be None")
if weight.get_shape().ndims is None:
raise ValueError("weight.get_shape().ndims cannot be None")
total_loss = _scale_losses(losses, weight)
num_present = _num_present(losses, weight)
return _safe_mean(total_loss, num_present)
def _reduce_batch_sum(x, name=None):
"""Given a tensor `x`, sums across all dimensions except dimension 0.
def _num_present(losses, weight, per_batch=False):
"""Computes the number of elements in the loss function induced by `weight`.
Given a tensor with the number of dimensions > 1, this will sum across all
dimensions except for dimension 0. This function is useful for summing the
loss (error) across all examples in a batch when training. As an example,
given a tensor of shape [batch_size, d1, d2], this function will sum across
dimensions d1 and d2, returning a tensor of shape [batch_size].
Tensors of dimension 1 are returned as-is, while tensors of dimension 0
raise a ValueError.
A given weight tensor induces different numbers of usable elements in the
`losses` tensor. The `weight` tensor is broadcast across `losses` for all
possible dimensions. For example, if `losses` is a tensor of dimension
[4, 5, 6, 3] and weight is a tensor of size [4, 5], then weight is, in effect,
tiled to match the size of `losses`. Following this effective tile, the total
number of present elements is the number of non-zero weights.
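In that example, if every entry of `weight` is non-zero, the number of present
elements is 4 * 5 * (6 * 3) = 360.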
Args:
x: A `Tensor` with dimension > 0.
name: A name for the operation (optional).
losses: A tensor of size [batch_size, d1, ... dN].
weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
per_batch: Whether to return the number of elements per batch or as a sum
total.
Returns:
A `Tensor` with values summed across all dimensions > 0.
The number of present (non-zero) elements in the losses tensor. If
`per_batch` is True, the value is returned as a tensor of size
[batch_size]. Otherwise, a single scalar tensor is returned.
"""
# To ensure that dims of [2, 1] get mapped to [2,]
weight = array_ops.squeeze(weight)
# If the weight is a scalar, it's easy to compute:
if weight.get_shape().ndims == 0:
batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
[0], [1]), [])
num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
math_ops.to_float(batch_size))
num_per_batch = math_ops.select(math_ops.equal(weight, 0),
0.0, num_per_batch)
num_per_batch = math_ops.mul(array_ops.ones(
array_ops.reshape(batch_size, [1])), num_per_batch)
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
# First, count the number of nonzero weights:
if weight.get_shape().ndims >= 1:
reduction_indices = range(1, weight.get_shape().ndims)
num_nonzero_per_batch = math_ops.reduce_sum(
math_ops.to_float(math_ops.not_equal(weight, 0)),
reduction_indices=reduction_indices)
# Next, determine the number of elements that weight would broadcast to:
broadcast_dims = array_ops.slice(array_ops.shape(losses),
[weight.get_shape().ndims], [-1])
num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))
num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast)
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
def absolute_difference(predictions, targets, weight=1.0, scope=None):
"""Adds an Absolute Difference loss to the training procedure.
`weight` acts as a coefficient for the loss. If a scalar is provided, then the
loss is simply scaled by the given value. If `weight` is a tensor of size
[batch_size], then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weight` vector. If the shape of
`weight` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
`weight`.
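For example (illustrative usage, assuming a batch of two samples):
# Scale the loss for the entire batch by 2:
tf.contrib.losses.absolute_difference(predictions, targets, weight=2.0)
# Weight the first sample twice as heavily as the second:
tf.contrib.losses.absolute_difference(
predictions, targets, weight=tf.constant([2.0, 1.0]))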
Args:
predictions: The predicted outputs.
targets: The ground truth output tensor, same dimensions as 'predictions'.
weight: Coefficients for the loss: a scalar, a tensor of shape
[batch_size], or a tensor whose shape matches `predictions`.
scope: The scope for the operations performed in computing the loss.
Returns:
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `x` has dimension 0.
ValueError: If the shape of `predictions` doesn't match that of `targets` or
if the shape of `weight` is invalid.
"""
return _reduce_batch(x, math_ops.reduce_sum, name)
with ops.op_scope([predictions, targets],
scope, "sum_of_squares_loss") as scope:
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
predictions = math_ops.to_float(predictions)
targets = math_ops.to_float(targets)
losses = math_ops.abs(math_ops.sub(predictions, targets))
return _compute_weighted_loss(losses, weight)
def _reduce_to_scalar(x, name=None):
"""Reduces losses to a scalar.
Given a tensor `x`, sums across all dimensions except dimension 0, then
average across dimension 0.
def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0,
label_smoothing=0, scope=None):
"""Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
Args:
x: A `Tensor` with dimension > 0.
name: A name for the operation (optional).
logits: [batch_size, num_classes] logits outputs of the network.
multi_class_labels: [batch_size, num_classes] target labels in {0, 1}.
weight: Coefficients for the loss. The tensor must be a scalar, a tensor of
shape [batch_size] or shape [batch_size, num_classes].
label_smoothing: If greater than 0 then smooth the labels.
scope: The scope for the operations performed in computing the loss.
Returns:
Calculate sum of losses per example, then average across batch.
A scalar `Tensor` representing the loss value.
"""
with ops.op_scope([x], name, "scalar") as scope:
return math_ops.reduce_mean(_reduce_batch_sum(x), name=scope)
with ops.op_scope([logits, multi_class_labels],
scope, "sigmoid_cross_entropy_loss"):
return _cross_entropy(logits, multi_class_labels, weight,
label_smoothing,
activation_fn=nn.sigmoid_cross_entropy_with_logits)
def _validate_predicted_and_target(predicted, target):
# TODO(ptucker): Optionally add assert op for shape check, for cases when
# shape is not fully defined at graph construction time?
predicted.get_shape().assert_is_compatible_with(target.get_shape())
tensor_util.assert_same_float_dtype([predicted, target])
def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
label_smoothing=0, scope=None):
"""Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
def _raw_absolute(predicted, target, name=None):
"""Computes and returns the per-example absolute loss.
Computes the per-example absolute value of the difference between
the target and predicted tensors. The tensors must have the same
shape.
It can scale the loss by a weight factor, and smooth the labels.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
logits: [batch_size, num_classes] logits outputs of the network.
onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
weight: Coefficients for the loss. The tensor must be a scalar or a tensor
of shape [batch_size].
label_smoothing: If greater than 0 then smooth the labels.
scope: the scope for the operations performed in computing the loss.
Returns:
A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.
A scalar `Tensor` representing the loss value.
"""
with ops.op_scope([logits, onehot_labels],
scope, "softmax_cross_entropy_loss"):
return _cross_entropy(logits, onehot_labels, weight,
label_smoothing,
activation_fn=nn.softmax_cross_entropy_with_logits)
def _cross_entropy(logits, onehot_labels, weight, label_smoothing,
activation_fn):
"""Adds a CrossEntropyLoss to the losses collection.
`weight` acts as a coefficient for the loss. If a scalar is provided,
then the loss is simply scaled by the given value. If `weight` is a
tensor of size [`batch_size`], then the loss weights apply to each
corresponding sample.
Args:
logits: [batch_size, num_classes] logits outputs of the network.
onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
weight: Coefficients for the loss. If the activation is SIGMOID, then the
weight shape must be one of [1], [batch_size] or logits.shape().
Otherwise, the weight shape must be either [1] or [batch_size].
label_smoothing: If greater than 0 then smooth the labels.
activation_fn: The activation function to use. The method must take three
arguments, the logits, the labels, and an operation name.
Returns:
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
ValueError: If the shape of `predictions` doesn't match that of `targets` or
if the shape of `weight` is invalid or if `weight` is None.
"""
with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
_validate_predicted_and_target(predicted, target)
return math_ops.abs(target - predicted, name=scope)
logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
num_classes = onehot_labels.get_shape()[1]
onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
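# When label_smoothing > 0, the smoothing below maps the correct class to a
# target of (1 - label_smoothing) + label_smoothing / num_classes and every
# other class to a target of label_smoothing / num_classes.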
if label_smoothing > 0:
smooth_positives = 1.0 - label_smoothing
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
losses = activation_fn(logits, onehot_labels, name="xentropy")
return _compute_weighted_loss(losses, weight)
def _raw_squared(predicted, target, name=None):
"""Computes and returns the per-example squared loss, divided by 2.
def log(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
"""Adds a Log Loss term to the training procedure.
Computes the per-example squared difference between the target and
predicted tensors. The tensors must have the same shape.
`weight` acts as a coefficient for the loss. If a scalar is provided, then the
loss is simply scaled by the given value. If `weight` is a tensor of size
[batch_size], then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weight` vector. If the shape of
`weight` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
`weight`.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
predictions: The predicted outputs.
targets: The ground truth output tensor, same dimensions as 'predictions'.
weight: Coefficients for the loss: a scalar, a tensor of shape
[batch_size], or a tensor whose shape matches `predictions`.
epsilon: A small increment to add to avoid taking a log of zero.
scope: The scope for the operations performed in computing the loss.
Returns:
A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
ValueError: If the shape of `predictions` doesn't match that of `targets` or
if the shape of `weight` is invalid.
"""
with ops.op_scope([predicted, target], name, "squared_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
_validate_predicted_and_target(predicted, target)
return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope)
with ops.op_scope([predictions, targets],
scope, "log_loss") as scope:
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
predictions = math_ops.to_float(predictions)
targets = math_ops.to_float(targets)
losses = -math_ops.mul(
targets,
math_ops.log(predictions + epsilon)) - math_ops.mul(
(1 - targets), math_ops.log(1 - predictions + epsilon))
return _compute_weighted_loss(losses, weight)
def absolute(predicted, target, name=None):
"""Reduces absolute losses to a scalar.
def sum_of_squares(predictions, targets, weight=1.0, scope=None):
"""Adds a Sum-of-Squares loss to the training procedure.
`weight` acts as a coefficient for the loss. If a scalar is provided, then the
loss is simply scaled by the given value. If `weight` is a tensor of size
[batch_size], then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weight` vector. If the shape of
`weight` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
`weight`.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
predictions: The predicted outputs.
targets: The ground truth output tensor, same dimensions as 'predictions'.
weight: Coefficients for the loss: a scalar, a tensor of shape
[batch_size], or a tensor whose shape matches `predictions`.
scope: The scope for the operations performed in computing the loss.
Returns:
Calculate sum of absolute losses per example, then average across batch.
"""
with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
return _reduce_to_scalar(_raw_absolute(predicted, target), name=scope)
def squared(predicted, target, name=None):
"""Reduces squared losses to a scalar.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
Calculate sum of squared losses per example, then average across batch.
"""
with ops.op_scope([predicted, target], name, "squared_loss") as scope:
return _reduce_to_scalar(_raw_squared(predicted, target), name=scope)
def logistic(logit, target, name=None):
"""Calculates the logistic cross-entropy loss, averaged across batches.
**WARNING:** `logit` must be unscaled.
See `tf.nn.sigmoid_cross_entropy_with_logits` for more details.
Args:
logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted logit values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`logit` tensor.
name: A name for the operation (optional).
Returns:
A scalar `tensor` of the logistic cross-entropy loss, averaged across
batches.
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `logit` and `target` shapes do not match.
ValueError: If the shape of `predictions` doesn't match that of `targets` or
if the shape of `weight` is invalid.
"""
with ops.op_scope([logit, target], name, "logistic_loss") as scope:
return _reduce_to_scalar(
nn.sigmoid_cross_entropy_with_logits(logit, target), name=scope)
with ops.op_scope([predictions, targets],
scope, "sum_of_squares_loss") as scope:
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
predictions = math_ops.to_float(predictions)
targets = math_ops.to_float(targets)
losses = math_ops.square(math_ops.sub(predictions, targets))
return _compute_weighted_loss(losses, weight)
def softmax(logit, target, name=None):
"""Calculates the softmax cross-entropy loss, averaged across batches.
def sum_of_pairwise_squares(predictions, targets, weight=1.0, scope=None):
"""Adds a pairwise-errors-squared loss to the training procedure.
**WARNING:** `logit` must be unscaled, while the `target` should be a
normalized probability prediction.
See `tf.nn.softmax_cross_entropy_with_logits` for more details.
Unlike the sum_of_squares loss, which is a measure of the differences between
corresponding elements of `predictions` and `targets`, sum_of_pairwise_squares
is a measure of the differences between pairs of corresponding elements of
`predictions` and `targets`.
For example, if `targets`=[a, b, c] and `predictions`=[x, y, z], there are
three pairs of differences, which are summed to compute the loss:
loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3
Note that since the inputs are of size [batch_size, d0, ... dN], the
corresponding pairs are computed within each batch sample but not across
samples within a batch. For example, if `predictions` represents a batch of
16 grayscale images of dimension [batch_size, 100, 200], then the set of pairs
is drawn from each image, but not across images.
`weight` acts as a coefficient for the loss. If a scalar is provided, then the
loss is simply scaled by the given value. If `weight` is a tensor of size
[batch_size], then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weight` vector.
Args:
logit: Tensor of actual values. Shape must have rank 2, generally
(batch, num_classes). num_classes must be > 1. For single-class
regression, use `logistic`. Type must be `tf.float32` or `tf.float64`.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`logit` tensor.
name: A name for the operation (optional).
predictions: The predicted outputs, a tensor of size [batch_size, d0, .. dN]
where N+1 is the total number of dimensions in `predictions`.
targets: The ground truth output tensor, whose shape must match the shape of
the `predictions` tensor.
weight: Coefficients for the loss: a scalar, a tensor of shape [batch_size],
or a tensor whose shape matches `predictions`.
scope: The scope for the operations performed in computing the loss.
Returns:
A scalar `tensor` of the softmax cross-entropy loss, averaged across
batches.
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `logit` and `target` shapes do not match.
ValueError: If the shape of `predictions` doesn't match that of `targets` or
if the shape of `weight` is invalid.
"""
with ops.op_scope([logit, target], name, "softmax_loss") as scope:
shape = logit.get_shape().with_rank(2)
if shape.dims[1] and shape.dims[1] < 2:
raise ValueError(
"Invalid shape %s; use logistic() instead for only 1 class." %
shape)
return _reduce_to_scalar(
nn.softmax_cross_entropy_with_logits(logit, target), name=scope)
with ops.op_scope([predictions, targets],
scope, "sum_of_pairwise_squares_loss") as scope:
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
predictions = math_ops.to_float(predictions)
targets = math_ops.to_float(targets)
weight = math_ops.to_float(ops.convert_to_tensor(weight))
diffs = math_ops.sub(predictions, targets)
# Need to verify here since the function doesn't use _compute_weighted_loss
if diffs.get_shape().ndims is None:
raise ValueError("diffs.get_shape().ndims cannot be None")
if weight.get_shape().ndims is None:
raise ValueError("weight.get_shape().ndims cannot be None")
reduction_indices = range(1, diffs.get_shape().ndims)
sum_squares_diff_per_batch = math_ops.reduce_sum(
math_ops.square(diffs),
reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weight, per_batch=True)
term1 = 2.0 * math_ops.div(sum_squares_diff_per_batch,
num_present_per_batch)
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
term2 = 2.0 * math_ops.div(math_ops.square(sum_diff),
math_ops.square(num_present_per_batch))
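# Note: term1 - term2 equals (1 / n^2) * sum_{i, j}((d_i - d_j)^2) taken over
# all ordered pairs of elements within a batch sample, where d holds the
# per-element differences and n is the number of present elements.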
loss = _scale_losses(term1 - term2, weight)
return math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
loss,
array_ops.zeros_like(loss),
name="value")
def cosine_distance(predictions, targets, dim, weight=1.0, scope=None):
"""Adds a cosine-distance loss to the training procedure.
Note that the function assumes that the predictions and targets are already
unit-normalized.
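If they are not, they can typically be normalized first, e.g. (illustrative,
assuming `raw_predictions` and `raw_targets` hold the unnormalized values):
predictions = tf.nn.l2_normalize(raw_predictions, dim)
targets = tf.nn.l2_normalize(raw_targets, dim)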
Args:
predictions: An arbitrary matrix.
targets: A `Tensor` whose shape matches 'predictions'
dim: The dimension along which the cosine distance is computed.
weight: Coefficients for the loss: a scalar, a tensor of shape
[batch_size], or a tensor whose shape matches `predictions`.
scope: The scope for the operations performed in computing the loss.
Returns:
A scalar `Tensor` representing the loss value.
Raises:
ValueError: If `predictions` shape doesn't match `targets` shape, or if the
shape of `weight` is invalid.
"""
with ops.op_scope([predictions, targets],
scope, "cosine_distance_loss") as scope:
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
if weight is None:
raise ValueError("`weight` cannot be None")
predictions = math_ops.to_float(predictions)
targets = math_ops.to_float(targets)
radial_diffs = math_ops.mul(predictions, targets)
losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
return _compute_weighted_loss(losses, weight)


@@ -23,250 +23,733 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import tensor_util
pi = 3.14
indiana_pi = 3.2 # https://en.wikipedia.org/wiki/Indiana_Pi_Bill
class AbsoluteDifferenceLossTest(tf.test.TestCase):
class AbsoluteLossTest(tf.test.TestCase):
def setUp(self):
self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testAbsoluteLoss(self):
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
actual = tf.constant([pi], name="pi")
actual_placeholder = tf.placeholder(tf.float32)
label = tf.constant([indiana_pi], name="lbl")
label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
expected_loss = abs(indiana_pi - pi)
with self.assertRaises(ValueError):
tf.contrib.losses.absolute_difference(
self._predictions, self._predictions, weight=None)
# Both shapes are set.
both_shapes_loss = tf.contrib.losses.absolute(actual, label)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
both_shapes_loss.eval(), expected_loss, decimal=6)
# No shape for 'actual' - check that the loss layer can be created.
no_actual_shape_loss = tf.contrib.losses.absolute(
actual_placeholder, label)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_actual_shape_loss.eval({actual_placeholder: [pi]}),
expected_loss, decimal=6)
# No shape for 'label' - check that the loss layer can be created.
no_label_shape_loss = tf.contrib.losses.absolute(
actual, label_placeholder)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
expected_loss, decimal=6)
# No shapes.
no_shape_loss = tf.contrib.losses.absolute(
actual_placeholder, label_placeholder)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_shape_loss.eval({label_placeholder: [indiana_pi],
actual_placeholder: [pi]}),
expected_loss, decimal=6)
# Evaluate the previous one again, but this time with different
# (matching) shapes. This should still work.
np.testing.assert_almost_equal(
no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
actual_placeholder: [pi, pi]}),
expected_loss, decimal=6)
class SquaredLossTest(tf.test.TestCase):
def testSquaredLoss(self):
def testAllCorrectNoLossWeight(self):
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._predictions)
with self.test_session():
actual = tf.constant([pi], name="pi")
actual_placeholder = tf.placeholder(tf.float32)
label = tf.constant([indiana_pi], name="lbl")
label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
self.assertAlmostEqual(0.0, loss.eval(), 3)
# Both shapes are set.
both_shapes_loss = tf.contrib.losses.squared(actual, label)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
both_shapes_loss.eval(), expected_loss, decimal=6)
# No shape for 'actual' - check that the loss layer can be created.
no_actual_shape_loss = tf.contrib.losses.squared(
actual_placeholder, label)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_actual_shape_loss.eval({actual_placeholder: [pi]}),
expected_loss, decimal=6)
# No shape for 'label' - check that the loss layer can be created.
no_label_shape_loss = tf.contrib.losses.squared(
actual, label_placeholder)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
expected_loss,
decimal=6)
# No shapes.
no_shape_loss = tf.contrib.losses.squared(
actual_placeholder, label_placeholder)
tf.initialize_all_variables().run()
np.testing.assert_almost_equal(
no_shape_loss.eval({label_placeholder: [indiana_pi],
actual_placeholder: [pi]}),
expected_loss, decimal=6)
# Evaluate the previous one again, but this time with different
# (matching) shapes. This should still work.
np.testing.assert_almost_equal(
no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
actual_placeholder: [pi, pi]}),
expected_loss, decimal=6)
class LogisticTest(tf.test.TestCase):
def _expected_loss(self, logit, target):
sigmoid = 1.0 / (1.0 + np.exp(-logit))
logistic_loss = (target * -np.log(sigmoid)) - (
(1.0 - target) * np.log(1.0 - sigmoid))
batch_losses = np.sum(logistic_loss, 1)
return np.sum(batch_losses) / len(batch_losses)
def testSimple(self):
logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
def testNonZeroLoss(self):
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets)
with self.test_session():
loss = tf.contrib.losses.logistic(tf.constant(logit), tf.constant(target))
self.assertAllClose(self._expected_loss(logit, target), loss.eval())
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testComplex(self):
def testNonZeroLossWithPythonScalarWeight(self):
weight = 2.3
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
# [batch] and [batch,1] work the same.
loss3x0 = tf.contrib.losses.logistic(
tf.constant([-1.0, 3.0, -3.0]),
tf.constant([0.3, 0.1, 0.4]))
tf.initialize_all_variables().run()
self.assertAllClose(1.536812, loss3x0.eval())
self.assertAlmostEqual(5.5 * weight, loss.eval(), 3)
expected_loss = 1.536812
actual3x1 = [[-1.0], [3.0], [-3.0]]
label3x1 = [[0.3], [0.1], [0.4]]
loss3x1 = tf.contrib.losses.logistic(
tf.constant(actual3x1), tf.constant(label3x1))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss3x1.eval())
# Batch average stays the same with repeats of the same examples.
loss9x1 = tf.contrib.losses.logistic(
tf.constant(actual3x1 * 3), tf.constant(label3x1 * 3))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss9x1.eval())
# Loss stays the same when adding another class with 0 loss.
loss3x2 = tf.contrib.losses.logistic(
tf.constant([[-1.0, 100.0], [3.0, -100.0], [-3.0, -100.0]]),
tf.constant([[0.3, 1.0], [0.1, 0.0], [0.4, 0.0]]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss3x2.eval())
# Loss stays the same with additional x1 dimension.
loss3x1x2 = tf.contrib.losses.logistic(
tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[0.4, 0.0]]]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss3x1x2.eval())
# We have set one label value to be out of range (the -0.4) and
# expect the absence of a crash since we did not set validate=True
loss = tf.contrib.losses.logistic(
tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[-0.4, 0.0]]]))
tf.initialize_all_variables().run()
loss.eval()
def testLogisticVsSoftmax(self):
def testNonZeroLossWithScalarTensorWeight(self):
weight = 2.3
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, tf.constant(weight))
with self.test_session():
# Each logit = L and target = T used for logistic_loss corresponds to
# logits [a, b] where a - b = L and targets [T, 1 - T] for
# softmax_loss.
self.assertAlmostEqual(5.5 * weight, loss.eval(), 3)
expected_loss = (0.69314718 + 1.01326168 + 2.10692811) / 3.0
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weight = tf.constant([1.2, 0.0], shape=[2,])
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
logistic_loss = tf.contrib.losses.logistic(
tf.constant([0.0, 1.0, 2.0]),
tf.constant([0.5, 0.3, 0.01]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, logistic_loss.eval())
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weight = tf.constant([1.2, 0.0], shape=[2, 1])
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
softmax_loss = tf.contrib.losses.softmax(
tf.constant([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]),
tf.constant([[0.5, 0.5], [0.3, 0.7], [0.01, 0.99]]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, softmax_loss.eval())
def testNonZeroLossWithSampleSpecificWeights(self):
weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weight = tf.zeros((2, 3))
loss = tf.contrib.losses.absolute_difference(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class SoftmaxTest(tf.test.TestCase):
class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
def testNoneWeightRaisesValueError(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.softmax_cross_entropy(logits, labels, weight=None)
def testAllCorrect(self):
with self.test_session():
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]])
loss = tf.contrib.losses.softmax(logits, labels)
labels = tf.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllWrong(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
with self.test_session():
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0]])
loss = tf.contrib.losses.softmax(logits, labels)
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
def testSoftmax(self):
def testNonZeroLossWithPythonScalarWeight(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = 2.3
with self.test_session():
# [batch] and [batch,1] fail, softmax_loss is only for multiclass.
self.assertRaisesRegexp(
ValueError, "must have rank 2", tf.contrib.losses.softmax,
tf.constant([-100.0, 10.0, 0.0]),
tf.constant([1.0, 1.0, 1.0]))
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)
self.assertRaisesRegexp(
ValueError, "only 1 class", tf.contrib.losses.softmax,
tf.constant([[-100.0], [10.0], [0.0]]),
tf.constant([[1.0], [1.0], [1.0]]))
def testNonZeroLossWithScalarTensorWeight(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = 2.3
with self.test_session():
loss = tf.contrib.losses.softmax_cross_entropy(
logits, labels, tf.constant(weight))
self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)
expected_loss = 3.173363
loss3x2 = tf.contrib.losses.softmax(
tf.constant([[-1.0, 1.0], [0.0, 0.0], [10.0, -1.0]]),
tf.constant([[0.5, 0.5], [0.3, 0.7], [0.3, 0.7]]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss3x2.eval())
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = tf.constant([1.2, 3.4, 5.6], shape=[3])
with self.test_session():
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)
# Loss stays the same when adding another negative class.
loss3x3 = tf.contrib.losses.softmax(
tf.constant(
[[-1.0, 1.0, -100.0], [0.0, 0.0, -100.0], [10.0, -1.0, -100.0]]),
tf.constant([[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.3, 0.7, 0.0]]))
tf.initialize_all_variables().run()
self.assertAllClose(expected_loss, loss3x3.eval())
def testAllWrongAllMissing(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = tf.constant([0, 0, 0], shape=[3])
with self.test_session():
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
self.assertAlmostEqual(loss.eval(), 0.0, 3)
# Fails for rank > 2.
self.assertRaisesRegexp(
ValueError, "must have rank 2", tf.contrib.losses.softmax,
tf.constant([[[-1.0, 1.0]], [[0.0, 0.0]], [[10.0, -1.0]]]),
tf.constant([[[0.5, 0.5]], [[0.3, 0.7]], [[0.3, 0.7]]]))
def testSomeMissing(self):
logits = tf.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = tf.constant([1.2, 0, 0], shape=[3])
with self.test_session():
loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
self.assertAlmostEqual(loss.eval(), 12.0, 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
with self.test_session():
logits = tf.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
labels = tf.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
weight = tf.constant([[3, 4, 5],
[2, 6, 0],
[8, 0, 1]])
with self.assertRaises(ValueError):
tf.contrib.losses.softmax_cross_entropy(
logits, labels, weight=weight).eval()
if __name__ == "__main__":
class SigmoidCrossEntropyLossTest(tf.test.TestCase):
def testAllCorrectSigmoid(self):
with self.test_session():
logits = tf.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
labels = tf.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllWrongSigmoid(self):
with self.test_session():
logits = tf.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
with self.test_session():
logits = tf.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
labels = tf.constant([[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])
weight = tf.constant([[3, 4, 5],
[2, 6, 0],
[8, 0, 1]])
loss = tf.contrib.losses.sigmoid_cross_entropy(
logits, labels, weight=weight)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 1700.0 / 7.0, 3)
def testMultiCorrectSigmoid(self):
logits = tf.constant([[100.0, -100.0, 100.0],
[100.0, 100.0, -100.0],
[-100.0, 100.0, 100.0]])
labels = tf.constant([[1, 0, 1],
[1, 1, 0],
[0, 1, 1]])
loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
with self.test_session():
self.assertAlmostEqual(loss.eval(), 0.0, 3)
class LogLossTest(tf.test.TestCase):
def setUp(self):
predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
targets = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))
self._np_predictions = predictions
self._np_targets = targets
epsilon = 1e-7
self._expected_losses = np.multiply(
targets, np.log(predictions + epsilon)) + np.multiply(
1 - targets, np.log(1 - predictions + epsilon))
self._predictions = tf.constant(predictions)
self._targets = tf.constant(targets)
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.log(self._targets, self._targets, weight=None)
def testAllCorrectNoLossWeight(self):
loss = tf.contrib.losses.log(self._targets, self._targets)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = tf.placeholder(tf.float32, shape=self._np_targets.shape)
loss = tf.contrib.losses.log(tf_predictions, self._targets)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(feed_dict={
tf_predictions: self._np_targets}), 3)
def testNonZeroLoss(self):
loss = tf.contrib.losses.log(self._predictions, self._targets)
with self.test_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weight = 2.3
loss = tf.contrib.losses.log(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weight = 2.3
loss = tf.contrib.losses.log(
self._predictions, self._targets, tf.constant(weight))
with self.test_session():
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
tf_predictions = tf.placeholder(tf.float32,
shape=self._np_predictions.shape)
weight = 2.3
loss = tf.contrib.losses.log(
tf_predictions, self._targets, tf.constant(weight))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
loss, 3)
def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
tf_predictions = tf.placeholder(tf.float32, shape=[None, None])
weight = 2.3
loss = tf.contrib.losses.log(
tf_predictions, self._targets, tf.constant(weight))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
loss, 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weight = tf.constant([1.2, 3.4], shape=[2])
expectedes = np.multiply(
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = tf.contrib.losses.log(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(-np.sum(expectedes) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
weight = tf.constant([1.2, 0], shape=[2])
expectedes = np.multiply(
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
loss = tf.contrib.losses.log(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(-np.sum(expectedes) / 3.0,
loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
weight = tf.constant([1.2, 0], shape=[2, 1])
expectedes = np.multiply(
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
loss = tf.contrib.losses.log(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(-np.sum(expectedes) / 3.0,
loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weight = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.log(self._predictions, self._targets, weight)
def testNonZeroLossWithMeasurementSpecificWeights(self):
weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
expectedes = np.multiply(self._expected_losses, weight)
loss = tf.contrib.losses.log(
self._predictions,
self._targets,
weight=tf.constant(weight, shape=(2, 3)))
with self.test_session():
self.assertAlmostEqual(-np.sum(expectedes) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
expectedes = np.multiply(self._expected_losses, weight)
tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
loss = tf.contrib.losses.log(
tf_predictions,
self._targets,
weight=tf.constant(weight, shape=(2, 3)))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expectedes) / 5.0, loss, 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
expectedes = np.multiply(self._expected_losses, weight)
loss = tf.contrib.losses.log(
self._predictions,
self._targets,
weight=tf.constant(weight, shape=(2, 3)))
with self.test_session():
self.assertAlmostEqual(-np.sum(expectedes), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
expectedes = np.multiply(self._expected_losses, weight)
tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
tf_weight = tf.constant(weight, shape=(2, 3))
loss = tf.contrib.losses.log(tf_predictions, self._targets, tf_weight)
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expectedes), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weight = tf.zeros(shape=(2, 3))
loss = tf.contrib.losses.log(
self._predictions, self._targets, tf_weight)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class SumOfSquaresLossTest(tf.test.TestCase):
def setUp(self):
self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.sum_of_squares(
self._predictions, self._predictions, weight=None)
def testAllCorrectNoLossWeight(self):
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._predictions)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets)
with self.test_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weight = 2.3
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(49.5 * weight, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weight = 2.3
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, tf.constant(weight))
with self.test_session():
self.assertAlmostEqual(49.5 * weight, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weight = tf.constant([1.2, 3.4], shape=[2,])
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weight = tf.constant([1.2, 3.4], shape=[2, 1])
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weight = tf.zeros((2, 3))
loss = tf.contrib.losses.sum_of_squares(
self._predictions, self._targets, weight)
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class SumOfPairwiseSquaresLossTest(tf.test.TestCase):
def setUp(self):
self._predictions = np.array([[4, 8, 12],
[8, 1, 3]])
self._targets = np.array([[1, 9, 2],
[-5, -5, 7]])
batch_size, dims = self._targets.shape
# Compute the expected loss 'manually'.
total = np.zeros((batch_size, 1))
for b in range(batch_size):
for i in range(dims):
for j in range(dims):
x = self._predictions[b, i].item() - self._predictions[b, j].item()
y = self._targets[b, i].item() - self._targets[b, j].item()
tmp = (x-y) * (x-y)
total[b] += tmp
self._expected_losses = np.divide(total, 9.0)
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._targets),
targets=tf.constant(self._targets),
weight=None)
def testAllCorrectNoLossWeight(self):
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._targets),
targets=tf.constant(self._targets))
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets))
with self.test_session():
self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weight = 2.3
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
weight=weight)
with self.test_session():
self.assertAlmostEqual(weight * np.sum(self._expected_losses),
loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weight = 2.3
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
weight=tf.constant(weight))
with self.test_session():
self.assertAlmostEqual(weight * np.sum(self._expected_losses),
loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
weight = 2.3
tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
tf_targets = tf.placeholder(tf.float32, shape=self._targets.shape)
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf_predictions,
targets=tf_targets,
weight=tf.constant(weight))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={
tf_predictions: self._predictions,
tf_targets: self._targets,
})
self.assertAlmostEqual(weight * np.sum(self._expected_losses), loss, 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weight = np.asarray([2.0, 1.0]).reshape((2, 1))
expectedes = np.multiply(weight, self._expected_losses)
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
weight=tf.constant(weight, shape=[2]))
with self.test_session():
self.assertAlmostEqual(np.sum(expectedes), loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
weight = np.asarray([1.2, 3.4]).reshape((2, 1))
expectedes = np.multiply(weight, self._expected_losses)
tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
tf_targets = tf.placeholder(tf.int32, shape=self._targets.shape)
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf_predictions,
targets=tf_targets,
weight=tf.constant(weight, shape=[2]))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={
tf_predictions: self._predictions,
tf_targets: self._targets,
})
self.assertAlmostEqual(np.sum(expectedes), loss, 3)
def testLossWithAllZeroBatchSpecificWeights(self):
weight = np.zeros((2, 1))
loss = tf.contrib.losses.sum_of_pairwise_squares(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
weight=tf.constant(weight, shape=[2]))
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class CosineDistanceLossTest(tf.test.TestCase):
def setUp(self):
self._predictions = np.asarray([[1, 0, 0], # Batch 1
[0, 0, -1],
[1, 0, 0], # Batch 2
[1, 0, 0],
[0, 0, -1], # Batch 3
[1, 0, 0]]).reshape((3, 2, 3))
self._targets = np.asarray([[1, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._targets),
targets=tf.constant(self._targets),
dim=2,
weight=None)
def testAllCorrectNoWeights(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._targets),
targets=tf.constant(self._targets),
dim=2)
with self.test_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
dim=2)
with self.test_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
predictions = np.matrix((
'0.819031913261206 0.567041924552012 0.087465312324590;'
'-0.665139432070255 -0.739487441769973 -0.103671883216994;'
'0.707106781186548 -0.707106781186548 0'))
targets = np.matrix((
'0.819031913261206 0.567041924552012 0.087465312324590;'
'0.665139432070255 0.739487441769973 0.103671883216994;'
'0.707106781186548 0.707106781186548 0'))
tf_preds = tf.constant(predictions, shape=(3, 1, 3), dtype=tf.float32)
tf_targets = tf.constant(targets, shape=(3, 1, 3), dtype=tf.float32)
loss = tf.contrib.losses.cosine_distance(tf_preds, tf_targets, dim=2)
with self.test_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
dim=2,
weight=tf.constant([1, 0, 0]))
with self.test_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
dim=2,
weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
with self.test_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testValueErrorThrownWithShapelessPlaceholder(self):
tf_predictions = tf.placeholder(tf.float32)
with self.test_session():
with self.assertRaises(ValueError):
tf.contrib.losses.cosine_distance(
predictions=tf_predictions,
targets=tf.constant(self._targets),
dim=2,
weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
tf_predictions = tf.placeholder(tf.float32, shape=self._targets.shape)
loss = tf.contrib.losses.cosine_distance(
predictions=tf_predictions,
targets=tf.constant(self._targets),
dim=2,
weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
with self.test_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
def testZeroLossWhenAllSampleSpecificWeightsAreZero(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
dim=2,
weight=tf.zeros((3,)))
with self.test_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
loss = tf.contrib.losses.cosine_distance(
predictions=tf.constant(self._predictions),
targets=tf.constant(self._targets),
dim=2,
weight=tf.zeros((3, 2)))
with self.test_session():
self.assertEqual(0, loss.eval())
if __name__ == '__main__':
tf.test.main()