Adding several loss functions including cosine_distance_loss, log_loss, softmax_cross_entropy_loss, and sum_of_pairwise_squares_loss. Refactoring old losses for consistency.
Change: 120130871
This commit is contained in:
parent
c2d9cb1d08
commit
ec89b0c218
@@ -18,5 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.losses.python.losses.loss_ops import *
from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference
from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance
from tensorflow.contrib.losses.python.losses.loss_ops import log
from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import sum_of_pairwise_squares
from tensorflow.contrib.losses.python.losses.loss_ops import sum_of_squares

@@ -14,272 +14,552 @@
# ==============================================================================
"""## Loss operations for use in neural networks.

The loss ops measure error for use in neural networks. These losses
can be used for measuring accuracy of a network in a regression task
or for regularization purposes (e.g., weight decay).
All of the loss functions take a pair of predictions and ground truth labels,
from which the loss is computed. It is assumed that the shape of both these
tensors is of the form [batch_size, d1, ... dN] where `batch_size` is the number
of samples in the batch and `d1` ... `dN` are the remaining dimensions.

These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.
It is common, when training with multiple loss functions, to adjust the relative
strengths of individual losses. This is performed by rescaling the losses via
a `weight` parameter passed to the loss functions. For example, if we were
training with both log_loss and sum_of_squares_loss, and we wished that the
log_loss penalty be twice as severe as the sum_of_squares_loss, we would
implement this as:

@@absolute
@@squared
@@logistic
@@softmax
  # Explicitly set the weight.
  tf.contrib.losses.log(predictions, targets, weight=2.0)

  # Uses default weight of 1.0
  tf.contrib.losses.sum_of_squares(predictions, targets)

While specifying a scalar loss rescales the loss over the entire batch,
we sometimes want to rescale the loss per batch sample. For example, if we have
certain examples that matter more to us to get correctly, we might want to have
a higher loss than for other samples whose mistakes matter less. In this case, we
can provide a weight vector of length `batch_size` which results in the loss
for each sample in the batch being scaled by the corresponding weight element.
For example, consider the case of a classification problem where we want to
maximize our accuracy but we are especially interested in obtaining high accuracy
for a specific class:

  inputs, labels = LoadData(batch_size=3)
  logits = MyModelPredictions(inputs)

  # Ensures that the loss for examples whose ground truth class is `3` is 5x
  # higher than the loss for all other examples.
  weight = tf.mul(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1

  onehot_labels = tf.one_hot(labels, num_classes=5)
  tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)

Finally, in certain cases, we may want to specify a different loss for every
single measurable value. For example, if we are performing per-pixel depth
prediction, or per-pixel denoising, a single batch sample has P values where P
is the number of pixels in the image. For many losses, the number of measurable
values matches the number of elements in the predictions and targets tensors.
For others, such as softmax_cross_entropy and cosine_distance, the
loss function reduces the dimensions of the inputs to produce a tensor of
losses for each measurable value. For example, softmax_cross_entropy takes as
input predictions and labels of dimension [batch_size, num_classes] but the
number of measurable values is [batch_size]. Consequently, when passing a weight
tensor to specify a different loss for every measurable value, the dimension of
the tensor will depend on the loss being used.

For a concrete example, consider the case of per-pixel depth prediction where
certain ground truth depth values are missing (due to sensor noise in the
capture process). In this case, we want to assign zero weight to losses for
these predictions.

  # 'depths' that are missing have a value of 0:
  images, depths = LoadData(...)
  predictions = MyModelPredictions(images)

  weight = tf.cast(tf.greater(depths, 0), tf.float32)
  tf.contrib.losses.sum_of_squares(predictions, depths, weight)

Note that when using weights for the losses, the final average is computed
by rescaling the losses by the weights and then dividing by the total number of
non-zero samples. For an arbitrary set of weights, this may not necessarily
produce a weighted average. Instead, it simply and transparently rescales the
per-element losses before averaging over the number of observations. For example,
if the losses computed by the loss function are an array [4, 1, 2, 3] and the
weights are an array [1, 0.5, 3, 9], then the average loss is:

  (4*1 + 1*0.5 + 2*3 + 3*9) / 4
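A quick NumPy check of this arithmetic (illustrative only; the values come from
the example above and are not part of the library):

  import numpy as np

  losses = np.array([4.0, 1.0, 2.0, 3.0])
  weights = np.array([1.0, 0.5, 3.0, 9.0])
  # Rescale, then divide by the number of non-zero weights. This is the plain
  # rescaling described above, not a true weighted average.
  average = np.sum(losses * weights) / np.count_nonzero(weights)  # 9.375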

However, with a single loss function and an arbitrary set of weights, one can
still easily create a loss function such that the resulting loss is a
weighted average over the individual prediction errors:

  images, labels = LoadData(...)
  predictions = MyModelPredictions(images)

  weight = MyComplicatedWeightingFunction(labels)
  weight = tf.div(weight, tf.size(weight))
  tf.contrib.losses.sum_of_squares(predictions, labels, weight)


@@absolute_difference
@@cosine_distance
@@log
@@sigmoid_cross_entropy
@@softmax_cross_entropy
@@sum_of_pairwise_squares
@@sum_of_squares
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

__all__ = ["absolute", "squared", "logistic", "softmax"]
__all__ = [
    "absolute_difference",
    "cosine_distance",
    "log",
    "sigmoid_cross_entropy",
    "softmax_cross_entropy",
    "sum_of_pairwise_squares",
    "sum_of_squares",
]


def _reduce_batch(x, reduce_fn, name=None):
  """Given a tensor `x`, calls reduce_fn to reduce it across dimensions.

  Given a tensor with number of dimensions > 1, _reduce_batch will reduce the
  tensor across all dimensions except for dimension 0. As an example, given a
  tensor of shape [batch_size, d1, d2], this function will reduce across
  dimensions d1 and d2, returning a tensor of shape [batch_size].

  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
  raise a ValueError.
def _scale_losses(losses, weight):
  """Computes the scaled loss.

  Args:
    x: A `Tensor` with dimension > 0.
    reduce_fn: A math_ops reduce function that takes arguments of
      `x`, `reduction_indices`, and `name`.
    name: A name for the operation (optional).
    losses: A `Tensor` of size [batch_size, d1, ... dN].
    weight: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
      The `losses` are reduced (tf.reduce_sum) until their dimension matches
      that of `weight`, at which point the reduced `losses` are element-wise
      multiplied by `weight` and a final reduce_sum is computed on the result.
      Conceptually, this operation is equivalent to broadcasting (tiling)
      `weight` to be the same size as `losses`, performing an element-wise
      multiplication, and summing the result.

  Returns:
    A `Tensor` with values reduced by reduce_fn across all dimensions > 0.
    A scalar tf.float32 `Tensor` whose value represents the sum of the scaled
    `losses`.
  """
  # First, compute the sum of the losses over all elements:
  start_index = max(0, weight.get_shape().ndims)
  reduction_indices = range(start_index, losses.get_shape().ndims)
  reduced_losses = math_ops.reduce_sum(losses,
                                       reduction_indices=reduction_indices)
  reduced_losses = math_ops.mul(reduced_losses, weight)
  return math_ops.reduce_sum(reduced_losses)
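As a rough illustration of what `_scale_losses` computes, here is a minimal
NumPy sketch for an assumed [batch_size, d1] `losses` tensor and a [batch_size]
`weight` (shapes and values are hypothetical, chosen only for the example):

  import numpy as np

  losses = np.array([[1.0, 2.0], [3.0, 4.0]])  # [batch_size=2, d1=2]
  weight = np.array([1.0, 0.5])                # [batch_size]
  # Reduce `losses` down to the rank of `weight`, multiply, then sum.
  reduced = losses.sum(axis=1)                 # [3.0, 7.0]
  scaled_total = (reduced * weight).sum()      # 3.0 + 3.5 = 6.5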


def _safe_mean(losses, num_present):
  """Computes a safe mean of the losses.

  Args:
    losses: A tensor whose elements contain individual loss measurements.
    num_present: The number of measurable losses in the tensor.

  Returns:
    A scalar representing the mean of the losses. If `num_present` is zero,
    then zero is returned.
  """
  total_loss = math_ops.reduce_sum(losses)
  return math_ops.select(num_present > 0,
                         math_ops.div(total_loss, num_present),
                         array_ops.zeros_like(total_loss),
                         name="value")


def _compute_weighted_loss(losses, weight):
  """Computes the weighted loss.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.

  Returns:
    A scalar `Tensor` that returns the weighted loss.

  Raises:
    ValueError: If `x` has dimension 0.
    ValueError: If the weight shape is not compatible with the losses shape or
      if the number of dimensions (rank) of either losses or weight is missing.
  """
  x = ops.convert_to_tensor(x, name="x")
  with ops.op_scope([x], name, "reduce_batch"):
    ndims = x.get_shape().ndims
    if ndims == 0:
      raise ValueError("Cannot reduce a scalar into batches.")
    elif ndims == 1:
      return x  # Don't include a useless reduction.
    elif ndims:
      reduction_indices = math_ops.range(1, ndims)
      shape = [x.get_shape().dims[0]]
    else:
      reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x)))
      shape = [None]  # We don't know much about the shape, but it is rank 1.
    result = reduce_fn(x, reduction_indices=reduction_indices)
  losses = math_ops.to_float(losses)
  weight = math_ops.to_float(ops.convert_to_tensor(weight))

    # Give a shape hint in case we have extra information.
    result.set_shape(shape)
    return result
  if losses.get_shape().ndims is None:
    raise ValueError("losses.get_shape().ndims cannot be None")
  if weight.get_shape().ndims is None:
    raise ValueError("weight.get_shape().ndims cannot be None")

  total_loss = _scale_losses(losses, weight)
  num_present = _num_present(losses, weight)
  return _safe_mean(total_loss, num_present)


def _reduce_batch_sum(x, name=None):
  """Given a tensor `x`, sums across all dimensions except dimension 0.
def _num_present(losses, weight, per_batch=False):
  """Computes the number of elements in the loss function induced by `weight`.

  Given a tensor with the number of dimensions > 1, this will sum across all
  dimensions except for dimension 0. This function is useful for summing the
  loss (error) across all examples in a batch when training. As an example,
  given a tensor of shape [batch_size, d1, d2], this function will sum across
  dimensions d1 and d2, returning a tensor of shape [batch_size].

  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
  raise a ValueError.
  A given weight tensor induces different numbers of usable elements in the
  `losses` tensor. The `weight` tensor is broadcast across `losses` for all
  possible dimensions. For example, if `losses` is a tensor of dimension
  [4, 5, 6, 3] and weight is a tensor of size [4, 5], then weight is, in effect,
  tiled to match the size of `losses`. Following this effective tile, the total
  number of present elements is the number of non-zero weights.

  Args:
    x: A `Tensor` with dimension > 0.
    name: A name for the operation (optional).
    losses: A tensor of size [batch_size, d1, ... dN].
    weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
    per_batch: Whether to return the number of elements per batch or as a sum
      total.

  Returns:
    A `Tensor` with values summed across all dimensions > 0.
    The number of present (non-zero) elements in the losses tensor. If
    `per_batch` is True, the value is returned as a tensor of size
    [batch_size]. Otherwise, a single scalar tensor is returned.
  """
  # To ensure that dims of [2, 1] get mapped to [2,]
  weight = array_ops.squeeze(weight)

  # If the weight is a scalar, it's easy to compute:
  if weight.get_shape().ndims == 0:
    batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
                                                   [0], [1]), [])
    num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
                                 math_ops.to_float(batch_size))
    num_per_batch = math_ops.select(math_ops.equal(weight, 0),
                                    0.0, num_per_batch)
    num_per_batch = math_ops.mul(array_ops.ones(
        array_ops.reshape(batch_size, [1])), num_per_batch)
    return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)

  # First, count the number of nonzero weights:
  if weight.get_shape().ndims >= 1:
    reduction_indices = range(1, weight.get_shape().ndims)
    num_nonzero_per_batch = math_ops.reduce_sum(
        math_ops.to_float(math_ops.not_equal(weight, 0)),
        reduction_indices=reduction_indices)

  # Next, determine the number of elements that weight would broadcast to:
  broadcast_dims = array_ops.slice(array_ops.shape(losses),
                                   [weight.get_shape().ndims], [-1])
  num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))

  num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast)
  return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
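For intuition about the non-scalar branch above, a small NumPy sketch with an
assumed [2, 3] `losses` tensor and a [2] `weight` (values are illustrative):

  import numpy as np

  losses = np.ones((2, 3))        # [batch_size=2, d1=3]
  weight = np.array([1.2, 0.0])   # [batch_size]
  nonzero_per_batch = (weight != 0).astype(float)       # [1.0, 0.0]
  num_to_broadcast = float(np.prod(losses.shape[1:]))   # 3.0
  num_per_batch = nonzero_per_batch * num_to_broadcast  # [3.0, 0.0]
  total_present = num_per_batch.sum()                   # 3.0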


def absolute_difference(predictions, targets, weight=1.0, scope=None):
  """Adds an Absolute Difference loss to the training procedure.

  `weight` acts as a coefficient for the loss. If a scalar is provided, then the
  loss is simply scaled by the given value. If `weight` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weight` vector. If the shape of
  `weight` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weight`.

  Args:
    predictions: The predicted outputs.
    targets: The ground truth output tensor, same dimensions as 'predictions'.
    weight: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `x` has dimension 0.

    ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid.
  """
  return _reduce_batch(x, math_ops.reduce_sum, name)
  with ops.op_scope([predictions, targets],
                    scope, "sum_of_squares_loss") as scope:
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())
    if weight is None:
      raise ValueError("`weight` cannot be None")
    predictions = math_ops.to_float(predictions)
    targets = math_ops.to_float(targets)
    losses = math_ops.abs(math_ops.sub(predictions, targets))
    return _compute_weighted_loss(losses, weight)
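A usage sketch (the tensor values are illustrative, mirroring the unit tests
further below):

  predictions = tf.constant([[4.0, 8.0, 12.0], [8.0, 1.0, 3.0]])
  targets = tf.constant([[1.0, 9.0, 2.0], [-5.0, -2.0, 6.0]])
  # A per-sample weight vector: the second sample is ignored entirely.
  weight = tf.constant([1.2, 0.0])
  loss = tf.contrib.losses.absolute_difference(predictions, targets, weight)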


def _reduce_to_scalar(x, name=None):
  """Reduces losses to a scalar.

  Given a tensor `x`, sums across all dimensions except dimension 0, then
  averages across dimension 0.
def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0,
                          label_smoothing=0, scope=None):
  """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.

  Args:
    x: A `Tensor` with dimension > 0.
    name: A name for the operation (optional).
    logits: [batch_size, num_classes] logits outputs of the network.
    multi_class_labels: [batch_size, num_classes] target labels in (0, 1).
    weight: Coefficients for the loss. The tensor must be a scalar, a tensor of
      shape [batch_size] or shape [batch_size, num_classes].
    label_smoothing: If greater than 0 then smooth the labels.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    Calculate sum of losses per example, then average across batch.
    A scalar `Tensor` representing the loss value.
  """
  with ops.op_scope([x], name, "scalar") as scope:
    return math_ops.reduce_mean(_reduce_batch_sum(x), name=scope)
  with ops.op_scope([logits, multi_class_labels],
                    scope, "sigmoid_cross_entropy_loss"):
    return _cross_entropy(logits, multi_class_labels, weight,
                          label_smoothing,
                          activation_fn=nn.sigmoid_cross_entropy_with_logits)


def _validate_predicted_and_target(predicted, target):
  # TODO(ptucker): Optionally add assert op for shape check, for cases when
  # shape is not fully defined at graph construction time?
  predicted.get_shape().assert_is_compatible_with(target.get_shape())
  tensor_util.assert_same_float_dtype([predicted, target])
def softmax_cross_entropy(logits, onehot_labels, weight=1.0,
                          label_smoothing=0, scope=None):
  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.


def _raw_absolute(predicted, target, name=None):
  """Computes and returns the per-example absolute loss.

  Computes the per-example absolute value of the difference between
  the target and predicted tensors. The tensors must have the same
  shape.
  It can scale the loss by a weight factor, and smooth the labels.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).
    logits: [batch_size, num_classes] logits outputs of the network.
    onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
    weight: Coefficients for the loss. The tensor must be a scalar or a tensor
      of shape [batch_size].
    label_smoothing: If greater than 0 then smooth the labels.
    scope: the scope for the operations performed in computing the loss.

  Returns:
    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.
    A scalar `Tensor` representing the loss value.
  """
  with ops.op_scope([logits, onehot_labels],
                    scope, "softmax_cross_entropy_loss"):
    return _cross_entropy(logits, onehot_labels, weight,
                          label_smoothing,
                          activation_fn=nn.softmax_cross_entropy_with_logits)


def _cross_entropy(logits, onehot_labels, weight, label_smoothing,
                   activation_fn):
  """Adds a CrossEntropyLoss to the losses collection.

  `weight` acts as a coefficient for the loss. If a scalar is provided,
  then the loss is simply scaled by the given value. If `weight` is a
  tensor of size [`batch_size`], then the loss weights apply to each
  corresponding sample.

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
    weight: Coefficients for the loss. If the activation is SIGMOID, then the
      weight shape must be one of [1], [batch_size] or logits.shape().
      Otherwise, the weight shape must be either [1] or [batch_size].
    label_smoothing: If greater than 0 then smooth the labels.
    activation_fn: The activation function to use. The method must take three
      arguments, the logits, the labels, and an operation name.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.

    ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid or if `weight` is None.
  """
  with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
    predicted = ops.convert_to_tensor(predicted, name="predicted")
    target = ops.convert_to_tensor(target, name="target")
    _validate_predicted_and_target(predicted, target)
    return math_ops.abs(target - predicted, name=scope)
  logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
  if weight is None:
    raise ValueError("`weight` cannot be None")

  num_classes = onehot_labels.get_shape()[1]
  onehot_labels = math_ops.cast(onehot_labels, logits.dtype)

  if label_smoothing > 0:
    smooth_positives = 1.0 - label_smoothing
    smooth_negatives = label_smoothing / num_classes
    onehot_labels = onehot_labels * smooth_positives + smooth_negatives

  losses = activation_fn(logits, onehot_labels, name="xentropy")
  return _compute_weighted_loss(losses, weight)


def _raw_squared(predicted, target, name=None):
  """Computes and returns the per-example squared loss, divided by 2.
def log(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
  """Adds a Log Loss term to the training procedure.

  Computes the per-example squared difference between the target and
  predicted tensors. The tensors must have the same shape.
  `weight` acts as a coefficient for the loss. If a scalar is provided, then the
  loss is simply scaled by the given value. If `weight` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weight` vector. If the shape of
  `weight` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weight`.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).
    predictions: The predicted outputs.
    targets: The ground truth output tensor, same dimensions as 'predictions'.
    weight: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    epsilon: A small increment to add to avoid taking a log of zero.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.

    ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid.
  """
  with ops.op_scope([predicted, target], name, "squared_loss") as scope:
    predicted = ops.convert_to_tensor(predicted, name="predicted")
    target = ops.convert_to_tensor(target, name="target")
    _validate_predicted_and_target(predicted, target)
    return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope)
  with ops.op_scope([predictions, targets],
                    scope, "log_loss") as scope:
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())
    if weight is None:
      raise ValueError("`weight` cannot be None")
    predictions = math_ops.to_float(predictions)
    targets = math_ops.to_float(targets)
    losses = -math_ops.mul(
        targets,
        math_ops.log(predictions + epsilon)) - math_ops.mul(
            (1 - targets), math_ops.log(1 - predictions + epsilon))
    return _compute_weighted_loss(losses, weight)
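The per-element log loss can be reproduced with plain NumPy (a sketch of the
formula above, with illustrative values; the op additionally applies `weight`
and averages over the non-zero weights):

  import numpy as np

  predictions = np.array([0.9, 0.2])
  targets = np.array([1.0, 0.0])
  epsilon = 1e-7
  losses = -(targets * np.log(predictions + epsilon) +
             (1 - targets) * np.log(1 - predictions + epsilon))
  # losses ~= [0.105, 0.223]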


def absolute(predicted, target, name=None):
  """Reduces absolute losses to a scalar.
def sum_of_squares(predictions, targets, weight=1.0, scope=None):
  """Adds a Sum-of-Squares loss to the training procedure.

  `weight` acts as a coefficient for the loss. If a scalar is provided, then the
  loss is simply scaled by the given value. If `weight` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weight` vector. If the shape of
  `weight` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weight`.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).
    predictions: The predicted outputs.
    targets: The ground truth output tensor, same dimensions as 'predictions'.
    weight: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    Calculate sum of absolute losses per example, then average across batch.
  """
  with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
    return _reduce_to_scalar(_raw_absolute(predicted, target), name=scope)


def squared(predicted, target, name=None):
  """Reduces squared losses to a scalar.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    Calculate sum of squared losses per example, then average across batch.
  """
  with ops.op_scope([predicted, target], name, "squared_loss") as scope:
    return _reduce_to_scalar(_raw_squared(predicted, target), name=scope)


def logistic(logit, target, name=None):
  """Calculates the logistic cross-entropy loss, averaged across batches.

  **WARNING:** `logit` must be unscaled.
  See `tf.nn.sigmoid_cross_entropy_with_logits` for more details.

  Args:
    logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted logit values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `logit` tensor.
    name: A name for the operation (optional).

  Returns:
    A scalar `tensor` of the logistic cross-entropy loss, averaged across
    batches.
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `logit` and `target` shapes do not match.
    ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid.
  """
  with ops.op_scope([logit, target], name, "logistic_loss") as scope:
    return _reduce_to_scalar(
        nn.sigmoid_cross_entropy_with_logits(logit, target), name=scope)
  with ops.op_scope([predictions, targets],
                    scope, "sum_of_squares_loss") as scope:
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())
    if weight is None:
      raise ValueError("`weight` cannot be None")
    predictions = math_ops.to_float(predictions)
    targets = math_ops.to_float(targets)
    losses = math_ops.square(math_ops.sub(predictions, targets))
    return _compute_weighted_loss(losses, weight)


def softmax(logit, target, name=None):
  """Calculates the softmax cross-entropy loss, averaged across batches.
def sum_of_pairwise_squares(predictions, targets, weight=1.0, scope=None):
  """Adds a pairwise-errors-squared loss to the training procedure.

  **WARNING:** `logit` must be unscaled, while the `target` should be a
  normalized probability prediction.
  See `tf.nn.softmax_cross_entropy_with_logits` for more details.
  Unlike the sum_of_squares loss, which is a measure of the differences between
  corresponding elements of `predictions` and `targets`, sum_of_pairwise_squares
  is a measure of the differences between pairs of corresponding elements of
  `predictions` and `targets`.

  For example, if `targets`=[a, b, c] and `predictions`=[x, y, z], there are
  three pairs of differences that are summed to compute the loss:
    loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3

  Note that since the inputs are of size [batch_size, d0, ... dN], the
  corresponding pairs are computed within each batch sample but not across
  samples within a batch. For example, if `predictions` represents a batch of
  16 grayscale images of dimension [batch_size, 100, 200], then the set of pairs
  is drawn from each image, but not across images.

  `weight` acts as a coefficient for the loss. If a scalar is provided, then the
  loss is simply scaled by the given value. If `weight` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weight` vector.

  Args:
    logit: Tensor of actual values. Shape must have rank 2, generally
      (batch, num_classes). num_classes must be > 1. For single-class
      regression, use `logistic`. Type must be `tf.float32` or `tf.float64`.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `logit` tensor.
    name: A name for the operation (optional).
    predictions: The predicted outputs, a tensor of size [batch_size, d0, .. dN]
      where N+1 is the total number of dimensions in `predictions`.
    targets: The ground truth output tensor, whose shape must match the shape of
      the `predictions` tensor.
    weight: Coefficients for the loss: a scalar, a tensor of shape [batch_size]
      or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `tensor` of the softmax cross-entropy loss, averaged across
    batches.
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `logit` and `target` shapes do not match.
    ValueError: If the shape of `predictions` doesn't match that of `targets` or
      if the shape of `weight` is invalid.
  """
  with ops.op_scope([logit, target], name, "softmax_loss") as scope:
    shape = logit.get_shape().with_rank(2)
    if shape.dims[1] and shape.dims[1] < 2:
      raise ValueError(
          "Invalid shape %s; use logistic() instead for only 1 class." %
          shape)
    return _reduce_to_scalar(
        nn.softmax_cross_entropy_with_logits(logit, target), name=scope)
  with ops.op_scope([predictions, targets],
                    scope, "sum_of_pairwise_squares_loss") as scope:
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())
    if weight is None:
      raise ValueError("`weight` cannot be None")
    predictions = math_ops.to_float(predictions)
    targets = math_ops.to_float(targets)
    weight = math_ops.to_float(ops.convert_to_tensor(weight))

    diffs = math_ops.sub(predictions, targets)

    # Need to verify here since the function doesn't use _compute_weighted_loss
    if diffs.get_shape().ndims is None:
      raise ValueError("diffs.get_shape().ndims cannot be None")
    if weight.get_shape().ndims is None:
      raise ValueError("weight.get_shape().ndims cannot be None")

    reduction_indices = range(1, diffs.get_shape().ndims)

    sum_squares_diff_per_batch = math_ops.reduce_sum(
        math_ops.square(diffs),
        reduction_indices=reduction_indices)
    num_present_per_batch = _num_present(diffs, weight, per_batch=True)

    term1 = 2.0 * math_ops.div(sum_squares_diff_per_batch,
                               num_present_per_batch)

    sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
    term2 = 2.0 * math_ops.div(math_ops.square(sum_diff),
                               math_ops.square(num_present_per_batch))

    loss = _scale_losses(term1 - term2, weight)

    return math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
                           loss,
                           array_ops.zeros_like(loss),
                           name="value")


def cosine_distance(predictions, targets, dim, weight=1.0, scope=None):
  """Adds a cosine-distance loss to the training procedure.

  Note that the function assumes that the predictions and targets are already
  unit-normalized.

  Args:
    predictions: An arbitrary matrix.
    targets: A `Tensor` whose shape matches 'predictions'
    dim: The dimension along which the cosine distance is computed.
    weight: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If predictions.shape doesn't match targets.shape, if the ignore
      mask is provided and its shape doesn't match targets.shape or if
      the ignore mask is not boolean valued.
  """
  with ops.op_scope([predictions, targets],
                    scope, "cosine_distance_loss") as scope:
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())
    if weight is None:
      raise ValueError("`weight` cannot be None")

    predictions = math_ops.to_float(predictions)
    targets = math_ops.to_float(targets)

    radial_diffs = math_ops.mul(predictions, targets)
    losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
    return _compute_weighted_loss(losses, weight)
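A usage sketch (values illustrative); both inputs are assumed to be
unit-normalized along `dim` before calling the loss:

  predictions = tf.nn.l2_normalize(tf.constant([[1.0, 2.0, 2.0]]), dim=1)
  targets = tf.nn.l2_normalize(tf.constant([[1.0, 0.0, 0.0]]), dim=1)
  loss = tf.contrib.losses.cosine_distance(predictions, targets, dim=1)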

@@ -23,250 +23,733 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import tensor_util

pi = 3.14
indiana_pi = 3.2  # https://en.wikipedia.org/wiki/Indiana_Pi_Bill

class AbsoluteDifferenceLossTest(tf.test.TestCase):

class AbsoluteLossTest(tf.test.TestCase):
  def setUp(self):
    self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
    self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))

  def testAbsoluteLoss(self):
  def testValueErrorThrownWhenWeightIsNone(self):
    with self.test_session():
      actual = tf.constant([pi], name="pi")
      actual_placeholder = tf.placeholder(tf.float32)
      label = tf.constant([indiana_pi], name="lbl")
      label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
      expected_loss = abs(indiana_pi - pi)
      with self.assertRaises(ValueError):
        tf.contrib.losses.absolute_difference(
            self._predictions, self._predictions, weight=None)

      # Both shapes are set.
      both_shapes_loss = tf.contrib.losses.absolute(actual, label)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          both_shapes_loss.eval(), expected_loss, decimal=6)

      # No shape for 'actual' - check that the loss layer can be created.
      no_actual_shape_loss = tf.contrib.losses.absolute(
          actual_placeholder, label)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_actual_shape_loss.eval({actual_placeholder: [pi]}),
          expected_loss, decimal=6)

      # No shape for 'label' - check that the loss layer can be created.
      no_label_shape_loss = tf.contrib.losses.absolute(
          actual, label_placeholder)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
          expected_loss, decimal=6)

      # No shapes.
      no_shape_loss = tf.contrib.losses.absolute(
          actual_placeholder, label_placeholder)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_shape_loss.eval({label_placeholder: [indiana_pi],
                              actual_placeholder: [pi]}),
          expected_loss, decimal=6)

      # Evaluate the previous one again, but this time with different
      # (matching) shapes. This should still work.
      np.testing.assert_almost_equal(
          no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
                              actual_placeholder: [pi, pi]}),
          expected_loss, decimal=6)


class SquaredLossTest(tf.test.TestCase):

  def testSquaredLoss(self):
  def testAllCorrectNoLossWeight(self):
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._predictions)
    with self.test_session():
      actual = tf.constant([pi], name="pi")
      actual_placeholder = tf.placeholder(tf.float32)
      label = tf.constant([indiana_pi], name="lbl")
      label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
      expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
      self.assertAlmostEqual(0.0, loss.eval(), 3)

      # Both shapes are set.
      both_shapes_loss = tf.contrib.losses.squared(actual, label)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          both_shapes_loss.eval(), expected_loss, decimal=6)

      # No shape for 'actual' - check that the loss layer can be created.
      no_actual_shape_loss = tf.contrib.losses.squared(
          actual_placeholder, label)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_actual_shape_loss.eval({actual_placeholder: [pi]}),
          expected_loss, decimal=6)

      # No shape for 'label' - check that the loss layer can be created.
      no_label_shape_loss = tf.contrib.losses.squared(
          actual, label_placeholder)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
          expected_loss,
          decimal=6)

      # No shapes.
      no_shape_loss = tf.contrib.losses.squared(
          actual_placeholder, label_placeholder)
      tf.initialize_all_variables().run()
      np.testing.assert_almost_equal(
          no_shape_loss.eval({label_placeholder: [indiana_pi],
                              actual_placeholder: [pi]}),
          expected_loss, decimal=6)

      # Evaluate the previous one again, but this time with different
      # (matching) shapes. This should still work.
      np.testing.assert_almost_equal(
          no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
                              actual_placeholder: [pi, pi]}),
          expected_loss, decimal=6)


class LogisticTest(tf.test.TestCase):

  def _expected_loss(self, logit, target):
    sigmoid = 1.0 / (1.0 + np.exp(-logit))
    logistic_loss = (target * -np.log(sigmoid)) - (
        (1.0 - target) * np.log(1.0 - sigmoid))
    batch_losses = np.sum(logistic_loss, 1)

    return np.sum(batch_losses) / len(batch_losses)

  def testSimple(self):
    logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
    target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
  def testNonZeroLoss(self):
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets)
    with self.test_session():
      loss = tf.contrib.losses.logistic(tf.constant(logit), tf.constant(target))
      self.assertAllClose(self._expected_loss(logit, target), loss.eval())
      self.assertAlmostEqual(5.5, loss.eval(), 3)

  def testComplex(self):
  def testNonZeroLossWithPythonScalarWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      # [batch] and [batch,1] work the same.
      loss3x0 = tf.contrib.losses.logistic(
          tf.constant([-1.0, 3.0, -3.0]),
          tf.constant([0.3, 0.1, 0.4]))
      tf.initialize_all_variables().run()
      self.assertAllClose(1.536812, loss3x0.eval())
      self.assertAlmostEqual(5.5 * weight, loss.eval(), 3)

      expected_loss = 1.536812
      actual3x1 = [[-1.0], [3.0], [-3.0]]
      label3x1 = [[0.3], [0.1], [0.4]]
      loss3x1 = tf.contrib.losses.logistic(
          tf.constant(actual3x1), tf.constant(label3x1))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss3x1.eval())

      # Batch average stays the same with repeats of the same examples.
      loss9x1 = tf.contrib.losses.logistic(
          tf.constant(actual3x1 * 3), tf.constant(label3x1 * 3))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss9x1.eval())

      # Loss stays the same when adding another class with 0 loss.
      loss3x2 = tf.contrib.losses.logistic(
          tf.constant([[-1.0, 100.0], [3.0, -100.0], [-3.0, -100.0]]),
          tf.constant([[0.3, 1.0], [0.1, 0.0], [0.4, 0.0]]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss3x2.eval())

      # Loss stays the same with additional x1 dimension.
      loss3x1x2 = tf.contrib.losses.logistic(
          tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
          tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[0.4, 0.0]]]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss3x1x2.eval())

      # We have set one label value to be out of range (the -0.4) and
      # expect the absence of a crash since we did not set validate=True
      loss = tf.contrib.losses.logistic(
          tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
          tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[-0.4, 0.0]]]))
      tf.initialize_all_variables().run()
      loss.eval()

  def testLogisticVsSoftmax(self):
  def testNonZeroLossWithScalarTensorWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, tf.constant(weight))
    with self.test_session():
      # Each logit = L and target = T used for logistic_loss corresponds to
      # logits [a, b] where a - b = L and targets [T, 1 - T] for
      # softmax_loss.
      self.assertAlmostEqual(5.5 * weight, loss.eval(), 3)

      expected_loss = (0.69314718 + 1.01326168 + 2.10692811) / 3.0
  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    weight = tf.constant([1.2, 0.0], shape=[2,])
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(5.6, loss.eval(), 3)

      logistic_loss = tf.contrib.losses.logistic(
          tf.constant([0.0, 1.0, 2.0]),
          tf.constant([0.5, 0.3, 0.01]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, logistic_loss.eval())
  def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
    weight = tf.constant([1.2, 0.0], shape=[2, 1])
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(5.6, loss.eval(), 3)

      softmax_loss = tf.contrib.losses.softmax(
          tf.constant([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]),
          tf.constant([[0.5, 0.5], [0.3, 0.7], [0.01, 0.99]]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, softmax_loss.eval())
  def testNonZeroLossWithSampleSpecificWeights(self):
    weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(16.6, loss.eval(), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(6.0, loss.eval(), 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    weight = tf.zeros((2, 3))
    loss = tf.contrib.losses.absolute_difference(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)


class SoftmaxTest(tf.test.TestCase):
class SoftmaxCrossEntropyLossTest(tf.test.TestCase):

  def testNoneWeightRaisesValueError(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[1, 0, 0],
                          [0, 1, 0],
                          [0, 0, 1]])
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.softmax_cross_entropy(logits, labels, weight=None)

  def testAllCorrect(self):
    with self.test_session():
      logits = tf.constant([[10.0, 0.0, 0.0],
                            [0.0, 10.0, 0.0],
                            [0.0, 0.0, 10.0]])
      labels = tf.constant([[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, 0.0, 1.0]])
      loss = tf.contrib.losses.softmax(logits, labels)
      labels = tf.constant([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1]])
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels)
      self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testAllWrong(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])

    with self.test_session():
      logits = tf.constant([[10.0, 0.0, 0.0],
                            [0.0, 10.0, 0.0],
                            [0.0, 0.0, 10.0]])
      labels = tf.constant([[0.0, 0.0, 1.0],
                            [1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0]])
      loss = tf.contrib.losses.softmax(logits, labels)
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels)
      self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(loss.eval(), 10.0, 3)

  def testSoftmax(self):
  def testNonZeroLossWithPythonScalarWeight(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])
    weight = 2.3
    with self.test_session():
      # [batch] and [batch,1] fail, softmax_loss is only for multiclass.
      self.assertRaisesRegexp(
          ValueError, "must have rank 2", tf.contrib.losses.softmax,
          tf.constant([-100.0, 10.0, 0.0]),
          tf.constant([1.0, 1.0, 1.0]))
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)

      self.assertRaisesRegexp(
          ValueError, "only 1 class", tf.contrib.losses.softmax,
          tf.constant([[-100.0], [10.0], [0.0]]),
          tf.constant([[1.0], [1.0], [1.0]]))
  def testNonZeroLossWithScalarTensorWeight(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])
    weight = 2.3
    with self.test_session():
      loss = tf.contrib.losses.softmax_cross_entropy(
          logits, labels, tf.constant(weight))
      self.assertAlmostEqual(loss.eval(), weight * 10.0, 3)

      expected_loss = 3.173363
      loss3x2 = tf.contrib.losses.softmax(
          tf.constant([[-1.0, 1.0], [0.0, 0.0], [10.0, -1.0]]),
          tf.constant([[0.5, 0.5], [0.3, 0.7], [0.3, 0.7]]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss3x2.eval())
  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])
    weight = tf.constant([1.2, 3.4, 5.6], shape=[3])
    with self.test_session():
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3)

      # Loss stays the same when adding another negative class.
      loss3x3 = tf.contrib.losses.softmax(
          tf.constant(
              [[-1.0, 1.0, -100.0], [0.0, 0.0, -100.0], [10.0, -1.0, -100.0]]),
          tf.constant([[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.3, 0.7, 0.0]]))
      tf.initialize_all_variables().run()
      self.assertAllClose(expected_loss, loss3x3.eval())
  def testAllWrongAllMissing(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])
    weight = tf.constant([0, 0, 0], shape=[3])
    with self.test_session():
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

      # Fails for rank > 2.
      self.assertRaisesRegexp(
          ValueError, "must have rank 2", tf.contrib.losses.softmax,
          tf.constant([[[-1.0, 1.0]], [[0.0, 0.0]], [[10.0, -1.0]]]),
          tf.constant([[[0.5, 0.5]], [[0.3, 0.7]], [[0.3, 0.7]]]))
  def testSomeMissing(self):
    logits = tf.constant([[10.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0],
                          [0.0, 0.0, 10.0]])
    labels = tf.constant([[0, 0, 1],
                          [1, 0, 0],
                          [0, 1, 0]])
    weight = tf.constant([1.2, 0, 0], shape=[3])
    with self.test_session():
      loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight)
      self.assertAlmostEqual(loss.eval(), 12.0, 3)

  def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
    with self.test_session():
      logits = tf.constant([[100.0, -100.0, -100.0],
                            [-100.0, 100.0, -100.0],
                            [-100.0, -100.0, 100.0]])
      labels = tf.constant([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1]])
      weight = tf.constant([[3, 4, 5],
                            [2, 6, 0],
                            [8, 0, 1]])

      with self.assertRaises(ValueError):
        tf.contrib.losses.softmax_cross_entropy(
            logits, labels, weight=weight).eval()


if __name__ == "__main__":
class SigmoidCrossEntropyLossTest(tf.test.TestCase):

  def testAllCorrectSigmoid(self):
    with self.test_session():
      logits = tf.constant([[100.0, -100.0, -100.0],
                            [-100.0, 100.0, -100.0],
                            [-100.0, -100.0, 100.0]])
      labels = tf.constant([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1]])
      loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testAllWrongSigmoid(self):
    with self.test_session():
      logits = tf.constant([[100.0, -100.0, -100.0],
                            [-100.0, 100.0, -100.0],
                            [-100.0, -100.0, 100.0]])
      labels = tf.constant([[0, 0, 1],
                            [1, 0, 0],
                            [0, 1, 0]])
      loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
      self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)

  def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
    with self.test_session():
      logits = tf.constant([[100.0, -100.0, -100.0],
                            [-100.0, 100.0, -100.0],
                            [-100.0, -100.0, 100.0]])
      labels = tf.constant([[0, 0, 1],
                            [1, 0, 0],
                            [0, 1, 0]])
      weight = tf.constant([[3, 4, 5],
                            [2, 6, 0],
                            [8, 0, 1]])
      loss = tf.contrib.losses.sigmoid_cross_entropy(
          logits, labels, weight=weight)
      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
      self.assertAlmostEqual(loss.eval(), 1700.0 / 7.0, 3)

  def testMultiCorrectSigmoid(self):
    logits = tf.constant([[100.0, -100.0, 100.0],
                          [100.0, 100.0, -100.0],
                          [-100.0, 100.0, 100.0]])
    labels = tf.constant([[1, 0, 1],
                          [1, 1, 0],
                          [0, 1, 1]])
    loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
    self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')

    with self.test_session():
      self.assertAlmostEqual(loss.eval(), 0.0, 3)


class LogLossTest(tf.test.TestCase):

  def setUp(self):
    predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
    targets = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))

    self._np_predictions = predictions
    self._np_targets = targets

    epsilon = 1e-7
    self._expected_losses = np.multiply(
        targets, np.log(predictions + epsilon)) + np.multiply(
            1 - targets, np.log(1 - predictions + epsilon))

    self._predictions = tf.constant(predictions)
    self._targets = tf.constant(targets)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.log(self._targets, self._targets, weight=None)

  def testAllCorrectNoLossWeight(self):
    loss = tf.contrib.losses.log(self._targets, self._targets)
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)

  def testAllCorrectNoLossWeightWithPlaceholder(self):
    tf_predictions = tf.placeholder(tf.float32, shape=self._np_targets.shape)
    loss = tf.contrib.losses.log(tf_predictions, self._targets)
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(feed_dict={
          tf_predictions: self._np_targets}), 3)

  def testNonZeroLoss(self):
    loss = tf.contrib.losses.log(self._predictions, self._targets)
    with self.test_session():
      self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
                             loss.eval(), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.log(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
                             loss.eval(), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.log(
        self._predictions, self._targets, tf.constant(weight))
    with self.test_session():
      self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
                             loss.eval(), 3)

def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
|
||||
tf_predictions = tf.placeholder(tf.float32,
|
||||
shape=self._np_predictions.shape)
|
||||
weight = 2.3
|
||||
loss = tf.contrib.losses.log(
|
||||
tf_predictions, self._targets, tf.constant(weight))
|
||||
with self.test_session() as sess:
|
||||
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
|
||||
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
|
||||
loss, 3)
|
||||
|
||||
def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
|
||||
tf_predictions = tf.placeholder(tf.float32, shape=[None, None])
|
||||
weight = 2.3
|
||||
loss = tf.contrib.losses.log(
|
||||
tf_predictions, self._targets, tf.constant(weight))
|
||||
with self.test_session() as sess:
|
||||
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
|
||||
self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0,
|
||||
loss, 3)
|
||||
|
||||
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
|
||||
weight = tf.constant([1.2, 3.4], shape=[2])
|
||||
expectedes = np.multiply(
|
||||
self._expected_losses,
|
||||
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions, self._targets, weight)
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(-np.sum(expectedes) / 6.0,
|
||||
loss.eval(), 3)
|
||||
|
||||
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
|
||||
weight = tf.constant([1.2, 0], shape=[2])
|
||||
expectedes = np.multiply(
|
||||
self._expected_losses,
|
||||
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions, self._targets, weight)
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(-np.sum(expectedes) / 3.0,
|
||||
loss.eval(), 3)
|
||||
|
||||
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
|
||||
weight = tf.constant([1.2, 0], shape=[2, 1])
|
||||
expectedes = np.multiply(
|
||||
self._expected_losses,
|
||||
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions, self._targets, weight)
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(-np.sum(expectedes) / 3.0,
|
||||
loss.eval(), 3)
|
||||
|
||||
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
|
||||
weight = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
|
||||
with self.test_session():
|
||||
with self.assertRaises(ValueError):
|
||||
tf.contrib.losses.log(self._predictions, self._targets, weight)
|
||||
|
||||
def testNonZeroLossWithMeasurementSpecificWeights(self):
|
||||
weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
|
||||
expectedes = np.multiply(self._expected_losses, weight)
|
||||
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions,
|
||||
self._targets,
|
||||
weight=tf.constant(weight, shape=(2, 3)))
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(-np.sum(expectedes) / 5.0, loss.eval(), 3)
|
||||
|
||||
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
|
||||
weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
|
||||
expectedes = np.multiply(self._expected_losses, weight)
|
||||
|
||||
tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
|
||||
loss = tf.contrib.losses.log(
|
||||
tf_predictions,
|
||||
self._targets,
|
||||
weight=tf.constant(weight, shape=(2, 3)))
|
||||
|
||||
with self.test_session() as sess:
|
||||
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
|
||||
self.assertAlmostEqual(-np.sum(expectedes) / 5.0, loss, 3)
|
||||
|
||||
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
|
||||
weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
|
||||
expectedes = np.multiply(self._expected_losses, weight)
|
||||
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions,
|
||||
self._targets,
|
||||
weight=tf.constant(weight, shape=(2, 3)))
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(-np.sum(expectedes), loss.eval(), 3)
|
||||
|
||||
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
|
||||
weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
|
||||
expectedes = np.multiply(self._expected_losses, weight)
|
||||
|
||||
tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
|
||||
tf_weight = tf.constant(weight, shape=(2, 3))
|
||||
loss = tf.contrib.losses.log(tf_predictions, self._targets, tf_weight)
|
||||
|
||||
with self.test_session() as sess:
|
||||
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
|
||||
self.assertAlmostEqual(-np.sum(expectedes), loss, 3)
|
||||
|
||||
def testLossWithSampleSpecificWeightsAllZero(self):
|
||||
tf_weight = tf.zeros(shape=(2, 3))
|
||||
loss = tf.contrib.losses.log(
|
||||
self._predictions, self._targets, tf_weight)
|
||||
with self.test_session():
|
||||
self.assertAlmostEqual(0.0, loss.eval(), 3)
|
||||
|
||||
|
||||
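# Editor's note -- a minimal sketch, not part of the original commit, of the
# per-element log loss the expectations above encode:
# -(y * log(p + eps) + (1 - y) * log(1 - p + eps)). The "/ 5.0" and "/ 3.0"
# divisors suggest the weighted sum is normalized by the number of measurements
# with a non-zero weight rather than by the weight sum; treat that as an
# assumption read off the tests. Weights are expected in a shape that
# broadcasts against the predictions (e.g. (2, 1) for per-sample weights).
def _expected_log_loss_sketch(predictions, targets, weights=1.0, epsilon=1e-7):
  predictions = np.asarray(predictions, dtype=np.float64)
  targets = np.asarray(targets, dtype=np.float64)
  per_element = -(targets * np.log(predictions + epsilon) +
                  (1 - targets) * np.log(1 - predictions + epsilon))
  weights = np.broadcast_to(np.asarray(weights, dtype=np.float64),
                            per_element.shape)
  num_present = np.count_nonzero(weights)
  return (per_element * weights).sum() / num_present if num_present else 0.0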
class SumOfSquaresLossTest(tf.test.TestCase):

  def setUp(self):
    self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
    self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.sum_of_squares(
            self._predictions, self._predictions, weight=None)

  def testAllCorrectNoLossWeight(self):
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._predictions)
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)

  def testNonZeroLoss(self):
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets)
    with self.test_session():
      self.assertAlmostEqual(49.5, loss.eval(), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(49.5 * weight, loss.eval(), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, tf.constant(weight))
    with self.test_session():
      self.assertAlmostEqual(49.5 * weight, loss.eval(), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    weight = tf.constant([1.2, 3.4], shape=[2])
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)

  def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
    weight = tf.constant([1.2, 3.4], shape=[2, 1])
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)

  def testNonZeroLossWithSampleSpecificWeights(self):
    weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(18.0, loss.eval(), 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    weight = tf.zeros((2, 3))
    loss = tf.contrib.losses.sum_of_squares(
        self._predictions, self._targets, weight)
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)
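# Editor's note -- a minimal sketch, not part of the original commit, showing
# where the constants above come from: the element-wise squared differences are
# 9, 1, 100, 169, 9, 9 (sum 297), so the unweighted mean is 297 / 6 = 49.5, the
# [3, 6, 5, 0, 4, 2] weighting gives 587 / 5, and per-sample weights [1.2, 3.4]
# give 767.8 / 6. As in the other sketches, normalization by the count of
# non-zero weights is an assumption read off the tests, and per-sample weights
# are expected in a broadcastable shape such as (2, 1).
def _expected_sum_of_squares_sketch(predictions, targets, weights=1.0):
  predictions = np.asarray(predictions, dtype=np.float64)
  targets = np.asarray(targets, dtype=np.float64)
  squared = (predictions - targets) ** 2
  weights = np.broadcast_to(np.asarray(weights, dtype=np.float64),
                            squared.shape)
  num_present = np.count_nonzero(weights)
  return (squared * weights).sum() / num_present if num_present else 0.0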
class SumOfPairwiseSquaresLossTest(tf.test.TestCase):

  def setUp(self):
    self._predictions = np.array([[4, 8, 12],
                                  [8, 1, 3]])
    self._targets = np.array([[1, 9, 2],
                              [-5, -5, 7]])

    batch_size, dims = self._targets.shape

    # Compute the expected loss 'manually'.
    total = np.zeros((batch_size, 1))
    for b in range(batch_size):
      for i in range(dims):
        for j in range(dims):
          x = self._predictions[b, i].item() - self._predictions[b, j].item()
          y = self._targets[b, i].item() - self._targets[b, j].item()
          tmp = (x - y) * (x - y)
          total[b] += tmp

    self._expected_losses = np.divide(total, 9.0)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.sum_of_pairwise_squares(
            predictions=tf.constant(self._targets),
            targets=tf.constant(self._targets),
            weight=None)

  def testAllCorrectNoLossWeight(self):
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._targets),
        targets=tf.constant(self._targets))
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)

  def testNonZeroLoss(self):
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets))
    with self.test_session():
      self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        weight=weight)
    with self.test_session():
      self.assertAlmostEqual(weight * np.sum(self._expected_losses),
                             loss.eval(), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weight = 2.3
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        weight=tf.constant(weight))
    with self.test_session():
      self.assertAlmostEqual(weight * np.sum(self._expected_losses),
                             loss.eval(), 3)

  def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
    weight = 2.3
    tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
    tf_targets = tf.placeholder(tf.float32, shape=self._targets.shape)
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf_predictions,
        targets=tf_targets,
        weight=tf.constant(weight))
    with self.test_session() as sess:
      loss = sess.run(loss, feed_dict={
          tf_predictions: self._predictions,
          tf_targets: self._targets,
      })
      self.assertAlmostEqual(weight * np.sum(self._expected_losses), loss, 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    weight = np.asarray([2.0, 1.0]).reshape((2, 1))
    expected_losses = np.multiply(weight, self._expected_losses)

    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        weight=tf.constant(weight, shape=[2]))
    with self.test_session():
      self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
    weight = np.asarray([1.2, 3.4]).reshape((2, 1))
    expected_losses = np.multiply(weight, self._expected_losses)

    tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
    tf_targets = tf.placeholder(tf.int32, shape=self._targets.shape)
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf_predictions,
        targets=tf_targets,
        weight=tf.constant(weight, shape=[2]))

    with self.test_session() as sess:
      loss = sess.run(loss, feed_dict={
          tf_predictions: self._predictions,
          tf_targets: self._targets,
      })
      self.assertAlmostEqual(np.sum(expected_losses), loss, 3)

  def testLossWithAllZeroBatchSpecificWeights(self):
    weight = np.zeros((2, 1))
    loss = tf.contrib.losses.sum_of_pairwise_squares(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        weight=tf.constant(weight, shape=[2]))
    with self.test_session():
      self.assertAlmostEqual(0.0, loss.eval(), 3)
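# Editor's note -- the triple loop in setUp above computes, for each sample,
# the mean over all (i, j) index pairs of ((p_i - p_j) - (t_i - t_j))^2. A
# vectorized NumPy equivalent is sketched below (the helper name is ours); for
# the 2x3 arrays used in these tests it reproduces self._expected_losses as a
# length-2 vector of per-sample values.
def _pairwise_squares_sketch(predictions, targets):
  predictions = np.asarray(predictions, dtype=np.float64)
  targets = np.asarray(targets, dtype=np.float64)
  # Pairwise within-sample differences, shape [batch_size, dims, dims].
  pred_diffs = predictions[:, :, None] - predictions[:, None, :]
  target_diffs = targets[:, :, None] - targets[:, None, :]
  dims = predictions.shape[1]
  return ((pred_diffs - target_diffs) ** 2).sum(axis=(1, 2)) / float(dims ** 2)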
class CosineDistanceLossTest(tf.test.TestCase):

  def setUp(self):
    self._predictions = np.asarray([[1, 0, 0],  # Batch 1
                                    [0, 0, -1],
                                    [1, 0, 0],  # Batch 2
                                    [1, 0, 0],
                                    [0, 0, -1],  # Batch 3
                                    [1, 0, 0]]).reshape((3, 2, 3))

    self._targets = np.asarray([[1, 0, 0],
                                [0, 0, 1],
                                [0, 1, 0],
                                [1, 0, 0],
                                [0, 0, 1],
                                [0, 1, 0]]).reshape((3, 2, 3))

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.cosine_distance(
            predictions=tf.constant(self._targets),
            targets=tf.constant(self._targets),
            dim=2,
            weight=None)

  def testAllCorrectNoWeights(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._targets),
        targets=tf.constant(self._targets),
        dim=2)
    with self.test_session():
      self.assertAlmostEqual(0, loss.eval(), 5)

  def testPartiallyCorrectWithIntegerValues(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        dim=2)
    with self.test_session():
      self.assertAlmostEqual(1, loss.eval(), 5)

  def testPartiallyCorrectFloatingPointValues(self):
    predictions = np.matrix((
        '0.819031913261206 0.567041924552012 0.087465312324590;'
        '-0.665139432070255 -0.739487441769973 -0.103671883216994;'
        '0.707106781186548 -0.707106781186548 0'))
    targets = np.matrix((
        '0.819031913261206 0.567041924552012 0.087465312324590;'
        '0.665139432070255 0.739487441769973 0.103671883216994;'
        '0.707106781186548 0.707106781186548 0'))

    tf_preds = tf.constant(predictions, shape=(3, 1, 3), dtype=tf.float32)
    tf_targets = tf.constant(targets, shape=(3, 1, 3), dtype=tf.float32)
    loss = tf.contrib.losses.cosine_distance(tf_preds, tf_targets, dim=2)

    with self.test_session():
      self.assertAlmostEqual(1.0, loss.eval(), 5)

  def testSampleSpecificWeights(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        dim=2,
        weight=tf.constant([1, 0, 0]))
    with self.test_session():
      self.assertEqual(1.0, loss.eval())

  def testMeasurementSpecificWeights(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        dim=2,
        weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
    with self.test_session():
      self.assertEqual(3.0 / 4.0, loss.eval())

  def testValueErrorThrownWithShapelessPlaceholder(self):
    tf_predictions = tf.placeholder(tf.float32)
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.losses.cosine_distance(
            predictions=tf_predictions,
            targets=tf.constant(self._targets),
            dim=2,
            weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))

  def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
    tf_predictions = tf.placeholder(tf.float32, shape=self._targets.shape)
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf_predictions,
        targets=tf.constant(self._targets),
        dim=2,
        weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
    with self.test_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
      self.assertEqual(3.0 / 4.0, loss)

  def testZeroLossWhenAllSampleSpecificWeightsAreZero(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        dim=2,
        weight=tf.zeros((3,)))
    with self.test_session():
      self.assertEqual(0, loss.eval())

  def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
    loss = tf.contrib.losses.cosine_distance(
        predictions=tf.constant(self._predictions),
        targets=tf.constant(self._targets),
        dim=2,
        weight=tf.zeros((3, 2)))
    with self.test_session():
      self.assertEqual(0, loss.eval())
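# Editor's note -- a minimal sketch, not part of the original commit, of the
# quantity the expectations above encode: for unit-length vectors along `dim`,
# the per-measurement distance is 1 - <prediction, target>, and the weighted
# loss divides the weighted sum by the number of measurements with non-zero
# weight (the [1, 0, 0, 1, 1, 1] case gives (0 + 0 + 2 + 1) / 4 = 3 / 4).
# Per-sample weights are expected in a broadcastable shape such as (3, 1).
def _cosine_distance_sketch(predictions, targets, dim, weights=1.0):
  predictions = np.asarray(predictions, dtype=np.float64)
  targets = np.asarray(targets, dtype=np.float64)
  per_measurement = 1.0 - np.sum(predictions * targets, axis=dim)
  weights = np.broadcast_to(np.asarray(weights, dtype=np.float64),
                            per_measurement.shape)
  num_present = np.count_nonzero(weights)
  return (per_measurement * weights).sum() / num_present if num_present else 0.0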
if __name__ == '__main__':
  tf.test.main()