Fixes GitHub #42458
Reference PR #37905 PiperOrigin-RevId: 327372634 Change-Id: I3bdcbffca4818f62b0c5227e065f896541c6b377
This commit is contained in:
parent
c82c43f658
commit
7e7641d95c
@ -1728,12 +1728,13 @@ def cosine_similarity(y_true, y_pred, axis=-1):
|
||||
class CosineSimilarity(LossFunctionWrapper):
|
||||
"""Computes the cosine similarity between labels and predictions.
|
||||
|
||||
Note that it is a negative quantity between -1 and 0, where 0 indicates
|
||||
orthogonality and values closer to -1 indicate greater similarity. This makes
|
||||
it usable as a loss function in a setting where you try to maximize the
|
||||
proximity between predictions and targets. If either `y_true` or `y_pred`
|
||||
is a zero vector, cosine similarity will be 0 regardless of the proximity
|
||||
between predictions and targets.
|
||||
Note that it is a number between -1 and 1. When it is a negative number
|
||||
between -1 and 0, 0 indicates orthogonality and values closer to -1
|
||||
indicate greater similarity. The values closer to 1 indicate greater
|
||||
dissimilarity. This makes it usable as a loss function in a setting
|
||||
where you try to maximize the proximity between predictions and targets.
|
||||
If either `y_true` or `y_pred` is a zero vector, cosine similarity will be 0
|
||||
regardless of the proximity between predictions and targets.
|
||||
|
||||
`loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user