From 0d534ca63d92596bd5efd4808e37cf39ba623fcf Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 30 Jul 2016 07:32:55 -0800
Subject: [PATCH] Move binary_hinge_loss under contrib/losses/

Change: 128892982
---
 .../layers/python/layers/target_column.py    | 29 ++++++---------
 .../contrib/losses/python/losses/__init__.py |  1 +
 .../contrib/losses/python/losses/loss_ops.py | 27 ++++++++++++++
 .../losses/python/losses/loss_ops_test.py    | 36 +++++++++++++++++++
 4 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 9f321895025..08280446723 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -22,6 +22,7 @@ import inspect

 import six

+from tensorflow.contrib import losses
 from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops


 def regression_target(label_name=None,
@@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
   """_TargetColumn for binary classification using SVMs."""

   def __init__(self, label_name, weight_column_name):
+    def loss_fn(logits, target):
+      check_shape_op = logging_ops.Assert(
+          math_ops.less_equal(array_ops.rank(target), 2),
+          ["target's shape should be either [batch_size, 1] or [batch_size]"])
+      with ops.control_dependencies([check_shape_op]):
+        target = array_ops.reshape(
+            target, shape=[array_ops.shape(target)[0], 1])
+      return losses.hinge_loss(logits, target)
+
     super(_BinarySvmTargetColumn, self).__init__(
-        loss_fn=_binary_hinge_loss,
+        loss_fn=loss_fn,
         n_classes=2,
         label_name=label_name,
         weight_column_name=weight_column_name)
@@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target):
   return loss_vec


-# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py.
-def _binary_hinge_loss(logits, target):
-  """Method that returns the loss vector for binary hinge loss."""
-  check_shape_op = logging_ops.Assert(
-      math_ops.less_equal(
-          array_ops.rank(target), 2),
-      ["target's shape should be either [batch_size, 1] or [batch_size]"])
-  with ops.control_dependencies([check_shape_op]):
-    target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1])
-  # First need to convert binary labels to -1/1 labels (as floats).
-  all_ones = array_ops.ones_like(logits)
-  labels = math_ops.sub(2 * math_ops.to_float(target), all_ones)
-  loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
-  return loss_vec
-
-
 def _softmax_cross_entropy_loss(logits, target):
   # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
   # Check that we got int32/int64 for classification.
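Note on the refactor above: the inline loss_fn keeps the old shape contract of
_binary_hinge_loss. It asserts that the target has rank at most 2 and reshapes
a rank-1 [batch_size] target to [batch_size, 1] before delegating to the new
losses.hinge_loss, whose target shape must match the logits exactly. A minimal
NumPy sketch of that shape handling (normalize_target is an illustrative name,
not part of the patch):

    import numpy as np

    def normalize_target(target):
      # Mirrors loss_fn: accept a [batch_size] or [batch_size, 1] target
      # and always hand a [batch_size, 1] array to the loss.
      target = np.asarray(target)
      if target.ndim > 2:
        raise ValueError(
            "target's shape should be either [batch_size, 1] or [batch_size]")
      return target.reshape(target.shape[0], 1)

    print(normalize_target([0.0, 1.0, 1.0]).shape)        # (3, 1)
    print(normalize_target([[0.0], [1.0], [1.0]]).shape)  # (3, 1)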
diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py
index 081d47e4b55..d8181632bf8 100644
--- a/tensorflow/contrib/losses/python/losses/__init__.py
+++ b/tensorflow/contrib/losses/python/losses/__init__.py
@@ -106,6 +106,7 @@ weighted average over the individual prediction errors:

 @@absolute_difference
 @@add_loss
+@@hinge_loss
 @@cosine_distance
 @@get_losses
 @@get_regularization_losses
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 99aab8b44c2..597e6aeda93 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops


 __all__ = ["absolute_difference",
@@ -33,6 +34,7 @@ __all__ = ["absolute_difference",
            "get_losses",
            "get_regularization_losses",
            "get_total_loss",
+           "hinge_loss",
            "log_loss",
            "sigmoid_cross_entropy",
            "softmax_cross_entropy",
@@ -410,6 +412,31 @@ def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
   return _compute_weighted_loss(losses, weight)


+def hinge_loss(logits, target, scope=None):
+  """Method that returns the loss tensor for hinge loss.
+
+  Args:
+    logits: The logits, a float tensor.
+    target: The ground truth output tensor. Its shape should match the shape of
+      logits. The values of the tensor are expected to be 0.0 or 1.0.
+    scope: The scope for the operations performed in computing the loss.
+
+  Returns:
+    A `Tensor` of same shape as logits and target representing the loss values
+      across the batch.
+
+  Raises:
+    ValueError: If the shapes of `logits` and `target` don't match.
+  """
+  with ops.op_scope([logits, target], scope, "hinge_loss") as scope:
+    logits.get_shape().assert_is_compatible_with(target.get_shape())
+    # We first need to convert binary labels to -1/1 labels (as floats).
+    target = math_ops.to_float(target)
+    all_ones = array_ops.ones_like(target)
+    labels = math_ops.sub(2 * target, all_ones)
+    return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+
+
 def sum_of_squares(predictions, targets, weight=1.0, scope=None):
   """Adds a Sum-of-Squares loss to the training procedure.

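The moved op implements the standard binary hinge loss: {0, 1} targets are
first mapped to {-1, +1} labels via labels = 2 * target - 1, and the
per-element loss is then max(0, 1 - labels * logits). A NumPy equivalent for
reference (hinge_loss_reference is illustrative, not part of the patch):

    import numpy as np

    def hinge_loss_reference(logits, target):
      # Elementwise max(0, 1 - (2*target - 1) * logits) for {0, 1} targets,
      # mirroring the TensorFlow op added above.
      logits = np.asarray(logits, dtype=np.float32)
      target = np.asarray(target, dtype=np.float32)
      if logits.shape != target.shape:
        raise ValueError("logits and target shapes must match")
      labels = 2.0 * target - 1.0  # {0, 1} -> {-1, +1}
      return np.maximum(0.0, 1.0 - labels * logits)

    # A correct prediction outside the unit margin costs nothing; an
    # incorrect one is penalized linearly in its distance from the margin.
    print(hinge_loss_reference([1.2, -0.4], [1.0, 1.0]))  # [0.  1.4]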
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 49460ec2279..824c24451be 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -499,6 +499,42 @@ class LogLossTest(tf.test.TestCase):
       self.assertAlmostEqual(0.0, loss.eval(), 3)


+class HingeLossTest(tf.test.TestCase):
+
+  def testIncompatibleShapes(self):
+    with self.test_session():
+      logits = tf.constant([[-1.0], [2.1]])
+      target = tf.constant([0.0, 1.0])
+      with self.assertRaises(ValueError):
+        _ = tf.contrib.losses.hinge_loss(logits, target).eval()
+
+  def testAllOutsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([1.2, -1.4, -1.0, 2.1])
+      target = tf.constant([1.0, 0.0, 0.0, 1.0])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
+
+  def testSomeInsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
+      target = tf.constant([[0.0], [0.0], [1.0], [1.0]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 1 and 4 are on the correct side of the hyperplane but within
+      # the margin so they incur some (small) loss.
+      self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
+
+  def testSomeMisclassified(self):
+    with self.test_session():
+      logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
+      target = tf.constant([[[1.0], [0.0], [0.0], [1.0]]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 2 and 4 are on the wrong side of the hyperplane so they incur
+      # some (fairly large) loss.
+      self.assertAllClose(
+          loss.eval(), [[[0.0], [1.4], [0.0], [2.1]]], atol=1e-3)
+
+
 class SumOfSquaresLossTest(tf.test.TestCase):

   def setUp(self):
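For completeness, a sketch of how a caller would exercise the relocated op
under the contemporary TF 0.x session API, using the same numbers as
testSomeInsideMargin above (values are approximate up to float rounding):

    import tensorflow as tf

    logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
    target = tf.constant([[0.0], [0.0], [1.0], [1.0]])
    loss = tf.contrib.losses.hinge_loss(logits, target)

    with tf.Session() as sess:
      # Per-example losses ~= [[0.3], [0.0], [0.0], [0.4]]: all four examples
      # are correctly classified, but only those inside the unit margin
      # (|logit| < 1) incur loss.
      print(sess.run(loss))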