Add CosineProximity to metrics and losses
PiperOrigin-RevId: 223407786
This commit is contained in:
parent
b0422005fa
commit
1f2eee4e23
@ -395,6 +395,40 @@ def cosine_proximity(y_true, y_pred):
|
||||
return -math_ops.reduce_sum(y_true * y_pred, axis=-1)
|
||||
|
||||
|
||||
class CosineProximity(Loss):
|
||||
"""Computes the cosine distance between `y_true` and `y_pred`.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
cosine_loss = tf.losses.CosineProximity()
|
||||
loss = cosine_loss([0., 1., 1.], [1., 0., 1.])
|
||||
print('Loss: ', loss.numpy()) # Loss: -0.5
|
||||
```
|
||||
|
||||
Usage with tf.keras API:
|
||||
|
||||
```python
|
||||
model = keras.models.Model(inputs, outputs)
|
||||
model.compile('sgd', loss=tf.losses.CosineProximity())
|
||||
```
|
||||
"""
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""Calculates the cosine proximity loss.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values.
|
||||
y_pred: The predicted values.
|
||||
|
||||
Returns:
|
||||
Cosine distance loss.
|
||||
"""
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
y_true = math_ops.cast(y_true, y_pred.dtype)
|
||||
return cosine_proximity(y_true, y_pred)
|
||||
|
||||
|
||||
# Aliases.
|
||||
|
||||
mse = MSE = mean_squared_error
|
||||
|
@ -442,5 +442,62 @@ class MeanSquaredLogarithmicErrorTest(test.TestCase):
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CosineProximityTest(test.TestCase):
|
||||
|
||||
def test_config(self):
|
||||
cosine_obj = keras.losses.CosineProximity(
|
||||
reduction=losses_impl.ReductionV2.SUM, name='cosine_loss')
|
||||
self.assertEqual(cosine_obj.name, 'cosine_loss')
|
||||
self.assertEqual(cosine_obj.reduction, losses_impl.ReductionV2.SUM)
|
||||
|
||||
def test_unweighted(self):
|
||||
cosine_obj = keras.losses.CosineProximity()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = cosine_obj(y_true, y_pred)
|
||||
self.assertAlmostEqual(self.evaluate(loss), -0.18722, 3)
|
||||
|
||||
def test_scalar_weighted(self):
|
||||
cosine_obj = keras.losses.CosineProximity()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = cosine_obj(y_true, y_pred, sample_weight=2.3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), -0.43060, 3)
|
||||
|
||||
def test_sample_weighted(self):
|
||||
cosine_obj = keras.losses.CosineProximity()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
|
||||
loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.15599, 3)
|
||||
|
||||
def test_timestep_weighted(self):
|
||||
cosine_obj = keras.losses.CosineProximity()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3, 1),
|
||||
dtype=dtypes.float32)
|
||||
sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
|
||||
loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
self.assertAlmostEqual(self.evaluate(loss), -2.0000, 3)
|
||||
|
||||
def test_zero_weighted(self):
|
||||
cosine_obj = keras.losses.CosineProximity()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = cosine_obj(y_true, y_pred, sample_weight=0)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0., 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -1477,6 +1477,43 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
self.tn[min_index] + self.fp[min_index])
|
||||
|
||||
|
||||
class CosineProximity(MeanMetricWrapper):
|
||||
"""Computes the cosine distance between the labels and predictions.
|
||||
|
||||
For example, if `y_true` is [0, 1, 1], and `y_pred` is [1, 0, 1], the cosine
|
||||
proximity is -0.5.
|
||||
|
||||
This metric keeps the average cosine distance between `predictions` and
|
||||
`labels` over a stream of data.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
m = tf.metrics.CosineProximity()
|
||||
m.update_state([0, 1, 1], [1, 0, 1])
|
||||
print('Final result: ', m.result().numpy()) # Final result: -0.5
|
||||
```
|
||||
|
||||
Usage with tf.keras API:
|
||||
|
||||
```python
|
||||
model = keras.models.Model(inputs, outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
loss='mse',
|
||||
metrics=[tf.metrics.CosineProximity()])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name='cosine_proximity', dtype=None):
|
||||
super(CosineProximity, self).__init__(cosine, name, dtype=dtype)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
if 'fn' in config:
|
||||
config.pop('fn')
|
||||
return super(CosineProximity, cls).from_config(config)
|
||||
|
||||
|
||||
def accuracy(y_true, y_pred):
|
||||
y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
|
||||
if y_true.dtype != y_pred.dtype:
|
||||
|
@ -1138,5 +1138,38 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(s_obj.tn), 25.)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CosineProximityTest(test.TestCase):
|
||||
|
||||
def test_config(self):
|
||||
cosine_obj = metrics.CosineProximity(name='my_cos', dtype=dtypes.int32)
|
||||
self.assertEqual(cosine_obj.name, 'my_cos')
|
||||
self.assertEqual(cosine_obj._dtype, dtypes.int32)
|
||||
|
||||
def test_unweighted(self):
|
||||
cosine_obj = metrics.CosineProximity()
|
||||
self.evaluate(variables.variables_initializer(cosine_obj.variables))
|
||||
|
||||
y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
|
||||
(1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
|
||||
y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
|
||||
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
|
||||
|
||||
update_op = cosine_obj.update_state(y_true, y_pred)
|
||||
self.evaluate(update_op)
|
||||
result = cosine_obj.result()
|
||||
self.assertAllClose(-0.60723, result, atol=1e-5)
|
||||
|
||||
def test_weighted(self):
|
||||
cosine_obj = metrics.CosineProximity()
|
||||
self.evaluate(variables.variables_initializer(cosine_obj.variables))
|
||||
y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
|
||||
(1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
|
||||
y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
|
||||
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
|
||||
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
|
||||
result = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
self.assertAllClose(-0.59916, self.evaluate(result), atol=1e-5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user