Move binary_hinge_loss under contrib/losses/
Change: 128892982
parent 6fa593f46d
commit 0d534ca63d
@@ -22,6 +22,7 @@ import inspect
 
 import six
 
+from tensorflow.contrib import losses
 from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
 
 
 def regression_target(label_name=None,
@@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
   """_TargetColumn for binary classification using SVMs."""
 
   def __init__(self, label_name, weight_column_name):
+    def loss_fn(logits, target):
+      check_shape_op = logging_ops.Assert(
+          math_ops.less_equal(array_ops.rank(target), 2),
+          ["target's shape should be either [batch_size, 1] or [batch_size]"])
+      with ops.control_dependencies([check_shape_op]):
+        target = array_ops.reshape(
+            target, shape=[array_ops.shape(target)[0], 1])
+      return losses.hinge_loss(logits, target)
+
     super(_BinarySvmTargetColumn, self).__init__(
-        loss_fn=_binary_hinge_loss,
+        loss_fn=loss_fn,
         n_classes=2,
         label_name=label_name,
         weight_column_name=weight_column_name)
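Note on the loss_fn wrapper added above: losses.hinge_loss requires logits and target to have matching shapes, but the SVM head may be fed a rank-1 [batch_size] target, so the wrapper asserts rank <= 2 and reshapes to [batch_size, 1] before delegating. A minimal NumPy sketch of that shape normalization (illustrative only, not part of this commit):

import numpy as np

target = np.array([0.0, 1.0, 1.0])            # rank-1 target, shape (3,)
assert target.ndim <= 2                       # mirrors the rank <= 2 Assert
target = target.reshape(target.shape[0], 1)   # (3, 1): now matches the logits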
@@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target):
   return loss_vec
 
 
-# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py.
-def _binary_hinge_loss(logits, target):
-  """Method that returns the loss vector for binary hinge loss."""
-  check_shape_op = logging_ops.Assert(
-      math_ops.less_equal(
-          array_ops.rank(target), 2),
-      ["target's shape should be either [batch_size, 1] or [batch_size]"])
-  with ops.control_dependencies([check_shape_op]):
-    target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1])
-  # First need to convert binary labels to -1/1 labels (as floats).
-  all_ones = array_ops.ones_like(logits)
-  labels = math_ops.sub(2 * math_ops.to_float(target), all_ones)
-  loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
-  return loss_vec
-
-
 def _softmax_cross_entropy_loss(logits, target):
   # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
   # Check that we got int32/int64 for classification.
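The deleted _binary_hinge_loss (re-added below as losses.hinge_loss) computes the elementwise hinge loss max(0, 1 - y * logits), first mapping 0/1 targets to -1/+1 labels via 2*t - 1. The same arithmetic in plain NumPy (an illustrative sketch; binary_hinge is a hypothetical helper, not part of this commit):

import numpy as np

def binary_hinge(logits, target):
  labels = 2.0 * target - 1.0                     # 0/1 labels -> -1/+1
  return np.maximum(0.0, 1.0 - labels * logits)   # max(0, 1 - y * logits)

print(binary_hinge(np.array([-0.7, -1.4, 1.4, 0.6]),
                   np.array([0.0, 0.0, 1.0, 1.0])))  # [0.3 0.  0.  0.4]

These values match the expectations in the tests added at the end of this commit.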
@@ -106,6 +106,7 @@ weighted average over the individual prediction errors:
 
 @@absolute_difference
 @@add_loss
+@@hinge_loss
 @@cosine_distance
 @@get_losses
 @@get_regularization_losses
@@ -25,6 +25,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
 
 
 __all__ = ["absolute_difference",
@@ -33,6 +34,7 @@ __all__ = ["absolute_difference",
            "get_losses",
            "get_regularization_losses",
            "get_total_loss",
+           "hinge_loss",
            "log_loss",
            "sigmoid_cross_entropy",
            "softmax_cross_entropy",
@@ -410,6 +412,31 @@ def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
   return _compute_weighted_loss(losses, weight)
 
 
+def hinge_loss(logits, target, scope=None):
+  """Method that returns the loss tensor for hinge loss.
+
+  Args:
+    logits: The logits, a float tensor.
+    target: The ground truth output tensor. Its shape should match the shape of
+      logits. The values of the tensor are expected to be 0.0 or 1.0.
+    scope: The scope for the operations performed in computing the loss.
+
+  Returns:
+    A `Tensor` of same shape as logits and target representing the loss values
+    across the batch.
+
+  Raises:
+    ValueError: If the shapes of `logits` and `target` don't match.
+  """
+  with ops.op_scope([logits, target], scope, "hinge_loss") as scope:
+    logits.get_shape().assert_is_compatible_with(target.get_shape())
+    # We first need to convert binary labels to -1/1 labels (as floats).
+    target = math_ops.to_float(target)
+    all_ones = array_ops.ones_like(target)
+    labels = math_ops.sub(2 * target, all_ones)
+    return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+
+
 def sum_of_squares(predictions, targets, weight=1.0, scope=None):
   """Adds a Sum-of-Squares loss to the training procedure.
 
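A minimal usage sketch for the hinge_loss added above, assuming the TensorFlow release this commit targets (tf.contrib.losses and tf.Session were removed in later releases):

import tensorflow as tf

logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
target = tf.constant([[0.0], [0.0], [1.0], [1.0]])
loss_vec = tf.contrib.losses.hinge_loss(logits, target)

with tf.Session() as sess:
  print(sess.run(loss_vec))   # [[0.3], [0.], [0.], [0.4]]

Unlike log_loss above, which returns _compute_weighted_loss(...), hinge_loss returns the per-element loss tensor; callers apply their own reduction or weighting.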
@@ -499,6 +499,42 @@ class LogLossTest(tf.test.TestCase):
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
+class HingeLossTest(tf.test.TestCase):
+
+  def testIncompatibleShapes(self):
+    with self.test_session():
+      logits = tf.constant([[-1.0], [2.1]])
+      target = tf.constant([0.0, 1.0])
+      with self.assertRaises(ValueError):
+        _ = tf.contrib.losses.hinge_loss(logits, target).eval()
+
+  def testAllOutsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([1.2, -1.4, -1.0, 2.1])
+      target = tf.constant([1.0, 0.0, 0.0, 1.0])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
+
+  def testSomeInsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
+      target = tf.constant([[0.0], [0.0], [1.0], [1.0]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 1 and 4 are on the correct side of the hyperplane but within
+      # the margin so they incur some (small) loss.
+      self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
+
+  def testSomeMisclassified(self):
+    with self.test_session():
+      logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
+      target = tf.constant([[[1.0], [0.0], [0.0], [1.0]]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 2 and 4 are on the wrong side of the hyperplane so they incur
+      # some (fairly large) loss.
+      self.assertAllClose(
+          loss.eval(), [[[0.0], [1.4], [0.0], [2.1]]], atol=1e-3)
+
+
 class SumOfSquaresLossTest(tf.test.TestCase):
 
   def setUp(self):
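For reference, the expected values in these tests follow directly from max(0, 1 - y * logits) with y = 2 * target - 1; e.g. in testSomeInsideMargin:

  example 1: target 0 -> y = -1, logits = -0.7: max(0, 1 - (-1)(-0.7)) = 0.3
  example 4: target 1 -> y = +1, logits =  0.6: max(0, 1 - (+1)(0.6))  = 0.4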