Adding V2 categorical hinge loss and metric.

PiperOrigin-RevId: 226687801

parent: ad6d0cb82b
commit: 1853b08d9c
@@ -463,6 +463,40 @@ class SquaredHinge(Loss):
    return squared_hinge(y_true, y_pred)


class CategoricalHinge(Loss):
  """Computes the categorical hinge loss between `y_true` and `y_pred`.

  Usage:

  ```python
  ch = tf.losses.CategoricalHinge()
  loss = ch([0., 1., 1.], [1., 0., 1.])
  print('Loss: ', loss.numpy())  # Loss: 1.0
  ```

  Usage with tf.keras API:

  ```python
  model = keras.models.Model(inputs, outputs)
  model.compile('sgd', loss=tf.losses.CategoricalHinge())
  ```
  """

  def call(self, y_true, y_pred):
    """Calculates the categorical hinge loss.

    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.

    Returns:
      Categorical hinge loss.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = math_ops.cast(y_true, y_pred.dtype)
    return categorical_hinge(y_true, y_pred)
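The `call` method defers to the `categorical_hinge` function defined earlier in this module. For orientation, its effect can be reproduced with a short NumPy sketch (an illustrative reconstruction, not the TensorFlow implementation itself): sum the scores assigned to the true classes, take the best score among the remaining classes, and hinge the margin at zero.

```python
import numpy as np

def categorical_hinge_ref(y_true, y_pred):
    # pos: total score assigned to the true classes.
    pos = np.sum(y_true * y_pred, axis=-1)
    # neg: highest score among the remaining classes.
    neg = np.max((1.0 - y_true) * y_pred, axis=-1)
    # Hinge on the margin between the best competing score and the true score.
    return np.maximum(0.0, neg - pos + 1.0)

print(categorical_hinge_ref(np.array([0., 1., 1.]),
                            np.array([1., 0., 1.])))  # 1.0, matching the docstring
```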


@keras_export('keras.metrics.mean_squared_error',
              'keras.metrics.mse',
              'keras.metrics.MSE',
@@ -937,5 +937,71 @@ class SquaredHingeTest(test.TestCase):
    self.assertAlmostEqual(self.evaluate(loss), 0., 3)


@test_util.run_all_in_graph_and_eager_modes
class CategoricalHingeTest(test.TestCase):

  def test_config(self):
    cat_hinge_obj = keras.losses.CategoricalHinge(
        reduction=losses_impl.ReductionV2.SUM, name='cat_hinge_loss')
    self.assertEqual(cat_hinge_obj.name, 'cat_hinge_loss')
    self.assertEqual(cat_hinge_obj.reduction, losses_impl.ReductionV2.SUM)

  def test_unweighted(self):
    cat_hinge_obj = keras.losses.CategoricalHinge()
    y_true = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
    y_pred = constant_op.constant([4, 8, 12, 8],
                                  shape=(2, 2),
                                  dtype=dtypes.float32)
    loss = cat_hinge_obj(y_true, y_pred)

    # pos = reduce_sum(y_true * y_pred, axis=-1) = [1*4 + 9*8, 2*12 + (-5)*8]
    #     = [76, -16]
    # neg = reduce_max((1. - y_true) * y_pred, axis=-1)
    #     = reduce_max([[0, -64], [-12, 48]], axis=-1) = [0, 48]
    # cat_hinge = max(0., neg - pos + 1.) = [0, 65]
    # reduced_loss = (0 + 65) / 2 = 32.5
    self.assertAlmostEqual(self.evaluate(loss), 32.5, 3)

  def test_scalar_weighted(self):
    cat_hinge_obj = keras.losses.CategoricalHinge()
    y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
    y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
                                  shape=(2, 3),
                                  dtype=dtypes.float32)
    loss = cat_hinge_obj(y_true, y_pred, sample_weight=2.3)
    # Per-sample losses are [0, 73]; (0 + 73) * 2.3 / 2 = 83.95.
    self.assertAlmostEqual(self.evaluate(loss), 83.95, 3)

    # Verify we get the same output when the same input is given.
    loss_2 = cat_hinge_obj(y_true, y_pred, sample_weight=2.3)
    self.assertAlmostEqual(self.evaluate(loss), self.evaluate(loss_2), 3)

  def test_sample_weighted(self):
    cat_hinge_obj = keras.losses.CategoricalHinge()
    y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
    y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
                                  shape=(2, 3),
                                  dtype=dtypes.float32)
    sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
    loss = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
    # Per-sample losses are [0, 73]; (0 * 1.2 + 73 * 3.4) / 2 = 124.1.
    self.assertAlmostEqual(self.evaluate(loss), 124.1, 3)

  def test_timestep_weighted(self):
    cat_hinge_obj = keras.losses.CategoricalHinge()
    y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
    y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
                                  shape=(2, 3, 1),
                                  dtype=dtypes.float32)
    sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
    loss = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
    # Per-timestep losses are [[0, 0, 0], [89, 6, 0]]; the only nonzero
    # weighted term is 6 * 4 = 24, and 24 / 6 timesteps = 4.0.
    self.assertAlmostEqual(self.evaluate(loss), 4.0, 3)

  def test_zero_weighted(self):
    cat_hinge_obj = keras.losses.CategoricalHinge()
    y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
    y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
                                  shape=(2, 3),
                                  dtype=dtypes.float32)
    loss = cat_hinge_obj(y_true, y_pred, sample_weight=0)
    self.assertAlmostEqual(self.evaluate(loss), 0., 3)


if __name__ == '__main__':
  test.main()
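The weighted assertions above follow from the default SUM_OVER_BATCH_SIZE reduction, which (as these tests exercise it) scales each per-sample loss by its weight, sums, and divides by the number of per-sample values. A hedged NumPy re-derivation of the 83.95 and 124.1 figures:

```python
import numpy as np

y_true = np.array([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)
y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=np.float32)

pos = np.sum(y_true * y_pred, axis=-1)          # [100., -24.]
neg = np.max((1.0 - y_true) * y_pred, axis=-1)  # [0., 48.]
per_sample = np.maximum(0.0, neg - pos + 1.0)   # [0., 73.]

print(np.sum(per_sample * 2.3) / 2)                   # 83.95 (scalar weight)
print(np.sum(per_sample * np.array([1.2, 3.4])) / 2)  # 124.1 (per-sample weights)
```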
@@ -33,6 +33,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import categorical_hinge
from tensorflow.python.keras.losses import cosine_proximity
from tensorflow.python.keras.losses import hinge
from tensorflow.python.keras.losses import kullback_leibler_divergence
@@ -1450,6 +1451,38 @@ class SquaredHinge(MeanMetricWrapper):
    return super(SquaredHinge, cls).from_config(config)


class CategoricalHinge(MeanMetricWrapper):
  """Computes the categorical hinge metric between `y_true` and `y_pred`.

  For example, if `y_true` is [0., 1., 1.] and `y_pred` is [1., 0., 1.],
  the categorical hinge metric value is 1.0.

  Usage:

  ```python
  h = tf.keras.metrics.CategoricalHinge()
  h.update_state([0., 1., 1.], [1., 0., 1.])
  print('Final result: ', h.result().numpy())  # Final result: 1.0
  ```

  Usage with tf.keras API:

  ```python
  model = keras.models.Model(inputs, outputs)
  model.compile('sgd', metrics=[tf.keras.metrics.CategoricalHinge()])
  ```
  """

  def __init__(self, name='categorical_hinge', dtype=None):
    super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype)

  @classmethod
  def from_config(cls, config):
    if 'fn' in config:
      config.pop('fn')
    return super(CategoricalHinge, cls).from_config(config)
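Because the metric subclasses MeanMetricWrapper, successive update_state calls accumulate a running (optionally weighted) mean of per-sample hinge values rather than recomputing from scratch. A minimal sketch of that accumulation pattern, assuming total/count state as in Keras' Mean metric (the `RunningMean` class here is a toy stand-in, not Keras code):

```python
import numpy as np

class RunningMean:
    """Toy stand-in for the total/count state of a Keras Mean metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def update_state(self, values, sample_weight=None):
        values = np.asarray(values, dtype=np.float64)
        if sample_weight is None:
            sample_weight = np.ones_like(values)
        self.total += np.sum(values * sample_weight)  # weighted sum of values
        self.count += np.sum(sample_weight)           # sum of weights seen so far

    def result(self):
        return self.total / self.count

m = RunningMean()
m.update_state([1.0, 0.0])  # e.g. per-sample categorical hinge values
m.update_state([0.0, 1.0])
print(m.result())           # 0.5
```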


def accuracy(y_true, y_pred):
  y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
  if y_true.dtype != y_pred.dtype:
@@ -1174,6 +1174,40 @@ class SquaredHingeTest(test.TestCase):
    self.assertAllClose(0.65714, self.evaluate(result), atol=1e-5)


@test_util.run_all_in_graph_and_eager_modes
class CategoricalHingeTest(test.TestCase):

  def test_config(self):
    cat_hinge_obj = metrics.CategoricalHinge(
        name='cat_hinge', dtype=dtypes.int32)
    self.assertEqual(cat_hinge_obj.name, 'cat_hinge')
    self.assertEqual(cat_hinge_obj._dtype, dtypes.int32)

  def test_unweighted(self):
    cat_hinge_obj = metrics.CategoricalHinge()
    self.evaluate(variables.variables_initializer(cat_hinge_obj.variables))
    y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                   (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
    y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                   (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))

    update_op = cat_hinge_obj.update_state(y_true, y_pred)
    self.evaluate(update_op)
    result = cat_hinge_obj.result()
    # Per-sample hinge values are [1, 0, 0, 1]; mean = 2 / 4 = 0.5.
    self.assertAllClose(0.5, result, atol=1e-5)

  def test_weighted(self):
    cat_hinge_obj = metrics.CategoricalHinge()
    self.evaluate(variables.variables_initializer(cat_hinge_obj.variables))
    y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                   (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
    y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                   (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
    sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
    result = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
    # Weighted mean: (1*1 + 0*1.5 + 0*2 + 1*2.5) / (1 + 1.5 + 2 + 2.5) = 0.5.
    self.assertAllClose(0.5, self.evaluate(result), atol=1e-5)
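Both 0.5 assertions can be checked by hand with the same reference computation used earlier; the per-sample hinge values for these inputs are [1, 0, 0, 1]. A hedged NumPy verification:

```python
import numpy as np

y_true = np.array([[0, 1, 0, 1, 0], [0, 0, 1, 1, 1],
                   [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]], dtype=np.float32)
y_pred = np.array([[0, 0, 1, 1, 0], [1, 1, 1, 1, 1],
                   [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]], dtype=np.float32)

pos = np.sum(y_true * y_pred, axis=-1)          # [1., 3., 2., 1.]
neg = np.max((1.0 - y_true) * y_pred, axis=-1)  # [1., 1., 0., 1.]
per_sample = np.maximum(0.0, neg - pos + 1.0)   # [1., 0., 0., 1.]

print(np.mean(per_sample))  # 0.5 (unweighted)

w = np.array([1.0, 1.5, 2.0, 2.5])
print(np.sum(per_sample * w) / np.sum(w))  # 0.5 (weighted: 3.5 / 7)
```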


def _get_model(compile_metrics):
  model_layers = [
      layers.Dense(3, activation='relu', kernel_initializer='ones'),