diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 83318d6c57a..6b8d5601029 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -395,6 +395,40 @@ def cosine_proximity(y_true, y_pred): return -math_ops.reduce_sum(y_true * y_pred, axis=-1) +class CosineProximity(Loss): + """Computes the cosine distance between `y_true` and `y_pred`. + + Usage: + + ```python + cosine_loss = tf.losses.CosineProximity() + loss = cosine_loss([0., 1., 1.], [1., 0., 1.]) + print('Loss: ', loss.numpy()) # Loss: -0.5 + ``` + + Usage with tf.keras API: + + ```python + model = keras.models.Model(inputs, outputs) + model.compile('sgd', loss=tf.losses.CosineProximity()) + ``` + """ + + def call(self, y_true, y_pred): + """Calculates the cosine proximity loss. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Cosine distance loss. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = math_ops.cast(y_true, y_pred.dtype) + return cosine_proximity(y_true, y_pred) + + # Aliases. mse = MSE = mean_squared_error diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index b5e9a24c997..cbf3c3524cc 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -442,5 +442,62 @@ class MeanSquaredLogarithmicErrorTest(test.TestCase): self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) +@test_util.run_all_in_graph_and_eager_modes +class CosineProximityTest(test.TestCase): + + def test_config(self): + cosine_obj = keras.losses.CosineProximity( + reduction=losses_impl.ReductionV2.SUM, name='cosine_loss') + self.assertEqual(cosine_obj.name, 'cosine_loss') + self.assertEqual(cosine_obj.reduction, losses_impl.ReductionV2.SUM) + + def test_unweighted(self): + cosine_obj = keras.losses.CosineProximity() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = cosine_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), -0.18722, 3) + + def test_scalar_weighted(self): + cosine_obj = keras.losses.CosineProximity() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = cosine_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), -0.43060, 3) + + def test_sample_weighted(self): + cosine_obj = keras.losses.CosineProximity() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1)) + loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 0.15599, 3) + + def test_timestep_weighted(self): + cosine_obj = keras.losses.CosineProximity() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3, 1), + dtype=dtypes.float32) + sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3)) + loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), -2.0000, 3) + + def test_zero_weighted(self): + cosine_obj = keras.losses.CosineProximity() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = cosine_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(self.evaluate(loss), 0., 3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 0519493a0af..3c2682e4c6f 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -1477,6 +1477,43 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): self.tn[min_index] + self.fp[min_index]) +class CosineProximity(MeanMetricWrapper): + """Computes the cosine distance between the labels and predictions. + + For example, if `y_true` is [0, 1, 1], and `y_pred` is [1, 0, 1], the cosine + proximity is -0.5. + + This metric keeps the average cosine distance between `predictions` and + `labels` over a stream of data. + + Usage: + ```python + m = tf.metrics.CosineProximity() + m.update_state([0, 1, 1], [1, 0, 1]) + print('Final result: ', m.result().numpy()) # Final result: -0.5 + ``` + + Usage with tf.keras API: + + ```python + model = keras.models.Model(inputs, outputs) + model.compile( + 'sgd', + loss='mse', + metrics=[tf.metrics.CosineProximity()]) + ``` + """ + + def __init__(self, name='cosine_proximity', dtype=None): + super(CosineProximity, self).__init__(cosine, name, dtype=dtype) + + @classmethod + def from_config(cls, config): + if 'fn' in config: + config.pop('fn') + return super(CosineProximity, cls).from_config(config) + + def accuracy(y_true, y_pred): y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) if y_true.dtype != y_pred.dtype: diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 9a88391dc16..92398acd8e6 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -1138,5 +1138,38 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(s_obj.tn), 25.) +@test_util.run_all_in_graph_and_eager_modes +class CosineProximityTest(test.TestCase): + + def test_config(self): + cosine_obj = metrics.CosineProximity(name='my_cos', dtype=dtypes.int32) + self.assertEqual(cosine_obj.name, 'my_cos') + self.assertEqual(cosine_obj._dtype, dtypes.int32) + + def test_unweighted(self): + cosine_obj = metrics.CosineProximity() + self.evaluate(variables.variables_initializer(cosine_obj.variables)) + + y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))) + y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))) + + update_op = cosine_obj.update_state(y_true, y_pred) + self.evaluate(update_op) + result = cosine_obj.result() + self.assertAllClose(-0.60723, result, atol=1e-5) + + def test_weighted(self): + cosine_obj = metrics.CosineProximity() + self.evaluate(variables.variables_initializer(cosine_obj.variables)) + y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))) + y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))) + sample_weight = constant_op.constant((1., 1.5, 2., 2.5)) + result = cosine_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(-0.59916, self.evaluate(result), atol=1e-5) + if __name__ == '__main__': test.main()