diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 5039a52ef19..51ae935bb2c 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -463,6 +463,40 @@ class SquaredHinge(Loss): return squared_hinge(y_true, y_pred) +class CategoricalHinge(Loss): + """Computes the categorical hinge loss between `y_true` and `y_pred`. + + Usage: + + ```python + ch = tf.losses.CategoricalHinge() + loss = ch([0., 1., 1.], [1., 0., 1.]) + print('Loss: ', loss.numpy()) # Loss: 1.0 + ``` + + Usage with tf.keras API: + + ```python + model = keras.models.Model(inputs, outputs) + model.compile('sgd', loss=tf.losses.CategoricalHinge()) + ``` + """ + + def call(self, y_true, y_pred): + """Calculates the categorical hinge loss. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Categorical hinge loss. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = math_ops.cast(y_true, y_pred.dtype) + return categorical_hinge(y_true, y_pred) + + @keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', 'keras.metrics.MSE', diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index c0ba2baf569..19ed7c8ed9d 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -937,5 +937,71 @@ class SquaredHingeTest(test.TestCase): self.assertAlmostEqual(self.evaluate(loss), 0., 3) +@test_util.run_all_in_graph_and_eager_modes +class CategoricalHingeTest(test.TestCase): + + def test_config(self): + cat_hinge_obj = keras.losses.CategoricalHinge( + reduction=losses_impl.ReductionV2.SUM, name='cat_hinge_loss') + self.assertEqual(cat_hinge_obj.name, 'cat_hinge_loss') + self.assertEqual(cat_hinge_obj.reduction, losses_impl.ReductionV2.SUM) + + def test_unweighted(self): + cat_hinge_obj = keras.losses.CategoricalHinge() + y_true = constant_op.constant([1, 9, 2, -5], shape=(2, 2)) + y_pred = constant_op.constant([4, 8, 12, 8], + shape=(2, 2), + dtype=dtypes.float32) + loss = cat_hinge_obj(y_true, y_pred) + + # pos = reduce_sum(y_true * y_pred) = [1*4+8*9, 12*2+8*-5] = [76, -16] + # neg = reduce_max((1. - y_true) * y_pred) = [[0, -64], [-12, 48]] = [0, 48] + # cat_hinge = max(0., neg - pos + 1.) = [0, 65] + # reduced_loss = (0 + 65)/2 = 32.5 + self.assertAlmostEqual(self.evaluate(loss), 32.5, 3) + + def test_scalar_weighted(self): + cat_hinge_obj = keras.losses.CategoricalHinge() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = cat_hinge_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), 83.95, 3) + + # Verify we get the same output when the same input is given + loss_2 = cat_hinge_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), self.evaluate(loss_2), 3) + + def test_sample_weighted(self): + cat_hinge_obj = keras.losses.CategoricalHinge() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1)) + loss = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 124.1, 3) + + def test_timestep_weighted(self): + cat_hinge_obj = keras.losses.CategoricalHinge() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3, 1), + dtype=dtypes.float32) + sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3)) + loss = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 4.0, 3) + + def test_zero_weighted(self): + cat_hinge_obj = keras.losses.CategoricalHinge() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = cat_hinge_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(self.evaluate(loss), 0., 3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index a7b5f2e75ea..70733235061 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge from tensorflow.python.keras.losses import cosine_proximity from tensorflow.python.keras.losses import hinge from tensorflow.python.keras.losses import kullback_leibler_divergence @@ -1450,6 +1451,38 @@ class SquaredHinge(MeanMetricWrapper): return super(SquaredHinge, cls).from_config(config) +class CategoricalHinge(MeanMetricWrapper): + """Computes the categorical hinge metric between `y_true` and `y_pred`. + + For example, if `y_true` is [0., 1., 1.], and `y_pred` is [1., 0., 1.] + the categorical hinge metric value is 1.0. + + Usage: + + ```python + h = tf.keras.metrics.CategoricalHinge() + h.update_state([0., 1., 1.], [1., 0., 1.]) + print('Final result: ', m.result().numpy()) # Final result: 1.0 + ``` + + Usage with tf.keras API: + + ```python + model = keras.models.Model(inputs, outputs) + model.compile('sgd', loss=tf.keras.metrics.CategoricalHinge()) + ``` + """ + + def __init__(self, name='categorical_hinge', dtype=None): + super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) + + @classmethod + def from_config(cls, config): + if 'fn' in config: + config.pop('fn') + return super(CategoricalHinge, cls).from_config(config) + + def accuracy(y_true, y_pred): y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) if y_true.dtype != y_pred.dtype: diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 896b39992b3..a6d714dcfb4 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -1174,6 +1174,40 @@ class SquaredHingeTest(test.TestCase): self.assertAllClose(0.65714, self.evaluate(result), atol=1e-5) +@test_util.run_all_in_graph_and_eager_modes +class CategoricalHingeTest(test.TestCase): + + def test_config(self): + cat_hinge_obj = metrics.CategoricalHinge( + name='cat_hinge', dtype=dtypes.int32) + self.assertEqual(cat_hinge_obj.name, 'cat_hinge') + self.assertEqual(cat_hinge_obj._dtype, dtypes.int32) + + def test_unweighted(self): + cat_hinge_obj = metrics.CategoricalHinge() + self.evaluate(variables.variables_initializer(cat_hinge_obj.variables)) + y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))) + y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))) + + update_op = cat_hinge_obj.update_state(y_true, y_pred) + self.evaluate(update_op) + result = cat_hinge_obj.result() + self.assertAllClose(0.5, result, atol=1e-5) + + def test_weighted(self): + cat_hinge_obj = metrics.CategoricalHinge() + self.evaluate(variables.variables_initializer(cat_hinge_obj.variables)) + y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))) + y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))) + sample_weight = constant_op.constant((1., 1.5, 2., 2.5)) + result = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.5, self.evaluate(result), atol=1e-5) + + def _get_model(compile_metrics): model_layers = [ layers.Dense(3, activation='relu', kernel_initializer='ones'),