Add CosineProximity to metrics and losses

PiperOrigin-RevId: 223407786
This commit is contained in:
Katherine Wu 2018-11-29 14:06:19 -08:00 committed by TensorFlower Gardener
parent b0422005fa
commit 1f2eee4e23
4 changed files with 161 additions and 0 deletions

View File

@ -395,6 +395,40 @@ def cosine_proximity(y_true, y_pred):
return -math_ops.reduce_sum(y_true * y_pred, axis=-1)
class CosineProximity(Loss):
"""Computes the cosine distance between `y_true` and `y_pred`.
Usage:
```python
cosine_loss = tf.losses.CosineProximity()
loss = cosine_loss([0., 1., 1.], [1., 0., 1.])
print('Loss: ', loss.numpy()) # Loss: -0.5
```
Usage with tf.keras API:
```python
model = keras.models.Model(inputs, outputs)
model.compile('sgd', loss=tf.losses.CosineProximity())
```
"""
def call(self, y_true, y_pred):
"""Calculates the cosine proximity loss.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Cosine distance loss.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
return cosine_proximity(y_true, y_pred)
# Aliases.
mse = MSE = mean_squared_error

View File

@ -442,5 +442,62 @@ class MeanSquaredLogarithmicErrorTest(test.TestCase):
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
@test_util.run_all_in_graph_and_eager_modes
class CosineProximityTest(test.TestCase):
def test_config(self):
cosine_obj = keras.losses.CosineProximity(
reduction=losses_impl.ReductionV2.SUM, name='cosine_loss')
self.assertEqual(cosine_obj.name, 'cosine_loss')
self.assertEqual(cosine_obj.reduction, losses_impl.ReductionV2.SUM)
def test_unweighted(self):
cosine_obj = keras.losses.CosineProximity()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = cosine_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), -0.18722, 3)
def test_scalar_weighted(self):
cosine_obj = keras.losses.CosineProximity()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = cosine_obj(y_true, y_pred, sample_weight=2.3)
self.assertAlmostEqual(self.evaluate(loss), -0.43060, 3)
def test_sample_weighted(self):
cosine_obj = keras.losses.CosineProximity()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), 0.15599, 3)
def test_timestep_weighted(self):
cosine_obj = keras.losses.CosineProximity()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3, 1),
dtype=dtypes.float32)
sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), -2.0000, 3)
def test_zero_weighted(self):
cosine_obj = keras.losses.CosineProximity()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = cosine_obj(y_true, y_pred, sample_weight=0)
self.assertAlmostEqual(self.evaluate(loss), 0., 3)
if __name__ == '__main__':
test.main()

View File

@ -1477,6 +1477,43 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
self.tn[min_index] + self.fp[min_index])
class CosineProximity(MeanMetricWrapper):
"""Computes the cosine distance between the labels and predictions.
For example, if `y_true` is [0, 1, 1], and `y_pred` is [1, 0, 1], the cosine
proximity is -0.5.
This metric keeps the average cosine distance between `predictions` and
`labels` over a stream of data.
Usage:
```python
m = tf.metrics.CosineProximity()
m.update_state([0, 1, 1], [1, 0, 1])
print('Final result: ', m.result().numpy()) # Final result: -0.5
```
Usage with tf.keras API:
```python
model = keras.models.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[tf.metrics.CosineProximity()])
```
"""
def __init__(self, name='cosine_proximity', dtype=None):
super(CosineProximity, self).__init__(cosine, name, dtype=dtype)
@classmethod
def from_config(cls, config):
if 'fn' in config:
config.pop('fn')
return super(CosineProximity, cls).from_config(config)
def accuracy(y_true, y_pred):
y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if y_true.dtype != y_pred.dtype:

View File

@ -1138,5 +1138,38 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(s_obj.tn), 25.)
@test_util.run_all_in_graph_and_eager_modes
class CosineProximityTest(test.TestCase):
def test_config(self):
cosine_obj = metrics.CosineProximity(name='my_cos', dtype=dtypes.int32)
self.assertEqual(cosine_obj.name, 'my_cos')
self.assertEqual(cosine_obj._dtype, dtypes.int32)
def test_unweighted(self):
cosine_obj = metrics.CosineProximity()
self.evaluate(variables.variables_initializer(cosine_obj.variables))
y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
(1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
update_op = cosine_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = cosine_obj.result()
self.assertAllClose(-0.60723, result, atol=1e-5)
def test_weighted(self):
cosine_obj = metrics.CosineProximity()
self.evaluate(variables.variables_initializer(cosine_obj.variables))
y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
(1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
result = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(-0.59916, self.evaluate(result), atol=1e-5)
if __name__ == '__main__':
test.main()