Fix bug in SensitivitySpecificityBase derived metrics.
Sub-classes of `SensitivitySpecificityBase` compute the value of one statistic, given a constraint on another statistic (e.g. compute recall at specified precision). Previously the documentation stated "Computes <statistic A> at a given <statistic B> <X>.", there was no guaranteed and consistent behaviour in case the specified <statistic B> cannot be assume value <X> on the provided sample of scores and labels (e.g. required recall of 0.7, but only either 0.6 or 0.8 can be reached). This change refines the function behaviour to "Computes best <statistic A> where <statistic B> >= <X>". This caters to common use cases of operating binary classifiers, with a requirement to e.g. maintain a minimal precision and maximise the recall - it is important not to report a recall from an operating point that undershoots the required precision (previously the closest precision would be selected, even it if is smaller). Because this changes (refines) the semantics of the metrics, some expected values in unittests etc. must be adapted. PiperOrigin-RevId: 299919680 Change-Id: I496ffaef51e599ccf82222a02e2cab55ade35f1a
This commit is contained in:
parent
2d51302640
commit
da5ab77e05
@ -1479,10 +1479,35 @@ class SensitivitySpecificityBase(Metric):
|
||||
K.batch_set_value(
|
||||
[(v, np.zeros((num_thresholds,))) for v in self.variables])
|
||||
|
||||
def _find_max_under_constraint(self, constrained, dependent, predicate):
|
||||
"""Returns the maximum of dependent_statistic that satisfies the constraint.
|
||||
|
||||
Args:
|
||||
constrained: Over these values the constraint
|
||||
is specified. A rank-1 tensor.
|
||||
dependent: From these values the maximum that satiesfies the
|
||||
constraint is selected. Values in this tensor and in
|
||||
`constrained` are linked by having the same threshold at each
|
||||
position, hence this tensor must have the same shape.
|
||||
predicate: A binary boolean functor to be applied to arguments
|
||||
`constrained` and `self.value`, e.g. `tf.greater`.
|
||||
|
||||
Returns maximal dependent value, if no value satiesfies the constraint 0.0.
|
||||
"""
|
||||
feasible = array_ops.where(predicate(constrained, self.value))
|
||||
feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
|
||||
|
||||
def get_max():
|
||||
return math_ops.reduce_max(array_ops.gather(dependent, feasible))
|
||||
|
||||
return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0)
|
||||
|
||||
|
||||
@keras_export('keras.metrics.SensitivityAtSpecificity')
|
||||
class SensitivityAtSpecificity(SensitivitySpecificityBase):
|
||||
"""Computes the sensitivity at a given specificity.
|
||||
"""Computes best sensitivity where specificity is >= specified value.
|
||||
|
||||
the sensitivity at a given specificity.
|
||||
|
||||
`Sensitivity` measures the proportion of actual positives that are correctly
|
||||
identified as such (tp / (tp + fn)).
|
||||
@ -1502,16 +1527,16 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
|
||||
|
||||
Usage:
|
||||
|
||||
>>> m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1)
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result().numpy()
|
||||
0.5
|
||||
|
||||
>>> m.reset_states()
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
||||
... sample_weight=[1, 0, 0, 1])
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[1, 1, 2, 2, 1])
|
||||
>>> m.result().numpy()
|
||||
1.0
|
||||
0.333333
|
||||
|
||||
Usage with tf.keras API:
|
||||
|
||||
@ -1542,20 +1567,12 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
|
||||
specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
|
||||
|
||||
def result(self):
|
||||
# Calculate specificities at all the thresholds.
|
||||
specificities = math_ops.div_no_nan(
|
||||
self.true_negatives, self.true_negatives + self.false_positives)
|
||||
|
||||
# Find the index of the threshold where the specificity is closest to the
|
||||
# given specificity.
|
||||
min_index = math_ops.argmin(
|
||||
math_ops.abs(specificities - self.value), axis=0)
|
||||
min_index = math_ops.cast(min_index, dtypes.int32)
|
||||
|
||||
# Compute sensitivity at that index.
|
||||
return math_ops.div_no_nan(
|
||||
self.true_positives[min_index],
|
||||
self.true_positives[min_index] + self.false_negatives[min_index])
|
||||
sensitivities = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_negatives)
|
||||
return self._find_max_under_constraint(
|
||||
specificities, sensitivities, math_ops.greater_equal)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
@ -1568,7 +1585,7 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
|
||||
|
||||
@keras_export('keras.metrics.SpecificityAtSensitivity')
|
||||
class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
"""Computes the specificity at a given sensitivity.
|
||||
"""Computes best specificity where sensitivity is >= specified value.
|
||||
|
||||
`Sensitivity` measures the proportion of actual positives that are correctly
|
||||
identified as such (tp / (tp + fn)).
|
||||
@ -1588,16 +1605,16 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
|
||||
Usage:
|
||||
|
||||
>>> m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1)
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result().numpy()
|
||||
1.0
|
||||
0.66666667
|
||||
|
||||
>>> m.reset_states()
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
||||
... sample_weight=[1, 0, 0, 1])
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[1, 1, 2, 2, 2])
|
||||
>>> m.result().numpy()
|
||||
1.0
|
||||
0.5
|
||||
|
||||
Usage with tf.keras API:
|
||||
|
||||
@ -1628,20 +1645,12 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
|
||||
|
||||
def result(self):
|
||||
# Calculate sensitivities at all the thresholds.
|
||||
sensitivities = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_negatives)
|
||||
|
||||
# Find the index of the threshold where the sensitivity is closest to the
|
||||
# requested value.
|
||||
min_index = math_ops.argmin(
|
||||
math_ops.abs(sensitivities - self.value), axis=0)
|
||||
min_index = math_ops.cast(min_index, dtypes.int32)
|
||||
|
||||
# Compute specificity at that index.
|
||||
return math_ops.div_no_nan(
|
||||
self.true_negatives[min_index],
|
||||
self.true_negatives[min_index] + self.false_positives[min_index])
|
||||
specificities = math_ops.div_no_nan(
|
||||
self.true_negatives, self.true_negatives + self.false_positives)
|
||||
return self._find_max_under_constraint(
|
||||
sensitivities, specificities, math_ops.greater_equal)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
@ -1654,7 +1663,7 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
|
||||
@keras_export('keras.metrics.PrecisionAtRecall')
|
||||
class PrecisionAtRecall(SensitivitySpecificityBase):
|
||||
"""Computes the precision at a given recall.
|
||||
"""Computes best precision where recall is >= specified value.
|
||||
|
||||
This metric creates four local variables, `true_positives`, `true_negatives`,
|
||||
`false_positives` and `false_negatives` that are used to compute the
|
||||
@ -1666,16 +1675,16 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
|
||||
|
||||
Usage:
|
||||
|
||||
>>> m = tf.keras.metrics.PrecisionAtRecall(0.8, num_thresholds=1)
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result().numpy()
|
||||
1.0
|
||||
0.5
|
||||
|
||||
>>> m.reset_states()
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
||||
... sample_weight=[1, 0, 0, 1])
|
||||
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[2, 2, 2, 1, 1])
|
||||
>>> m.result().numpy()
|
||||
1.0
|
||||
0.33333333
|
||||
|
||||
Usage with tf.keras API:
|
||||
|
||||
@ -1709,20 +1718,12 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
|
||||
dtype=dtype)
|
||||
|
||||
def result(self):
|
||||
# Calculate recall at all the thresholds.
|
||||
recalls = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_negatives)
|
||||
|
||||
# Find the index of the threshold where the recall is closest to the
|
||||
# requested value.
|
||||
min_index = math_ops.argmin(
|
||||
math_ops.abs(recalls - self.value), axis=0)
|
||||
min_index = math_ops.cast(min_index, dtypes.int32)
|
||||
|
||||
# Compute precision at that index.
|
||||
return math_ops.div_no_nan(
|
||||
self.true_positives[min_index],
|
||||
self.true_positives[min_index] + self.false_positives[min_index])
|
||||
precisions = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_positives)
|
||||
return self._find_max_under_constraint(
|
||||
recalls, precisions, math_ops.greater_equal)
|
||||
|
||||
def get_config(self):
|
||||
config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
|
||||
@ -1732,7 +1733,7 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
|
||||
|
||||
@keras_export('keras.metrics.RecallAtPrecision')
|
||||
class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
"""Computes the maximally achievable recall at a required precision.
|
||||
"""Computes best recall where precision is >= specified value.
|
||||
|
||||
For a given score-label-distribution the required precision might not
|
||||
be achievable, in this case 0.0 is returned as recall.
|
||||
@ -1747,7 +1748,7 @@ class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
|
||||
Usage:
|
||||
|
||||
>>> m = tf.keras.metrics.RecallAtPrecision(0.8, num_thresholds=1)
|
||||
>>> m = tf.keras.metrics.RecallAtPrecision(0.8)
|
||||
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> m.result().numpy()
|
||||
0.5
|
||||
@ -1790,21 +1791,12 @@ class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
dtype=dtype)
|
||||
|
||||
def result(self):
|
||||
# Calculate precision and recall at all the thresholds.
|
||||
# All recalls are computed, because they are not a monotoneous function of
|
||||
# precision and we want to search for the highest feasible recall.
|
||||
precisions = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_positives)
|
||||
recalls = math_ops.div_no_nan(
|
||||
self.true_positives, self.true_positives + self.false_negatives)
|
||||
# Find best recall where the precision is as good as required.
|
||||
feasible = array_ops.where(math_ops.greater_equal(precisions, self.value))
|
||||
feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
|
||||
best_recall = control_flow_ops.cond(
|
||||
feasible_exists,
|
||||
lambda: math_ops.reduce_max(array_ops.gather(recalls, feasible)),
|
||||
lambda: 0.0)
|
||||
return best_recall
|
||||
return self._find_max_under_constraint(
|
||||
precisions, recalls, math_ops.greater_equal)
|
||||
|
||||
def get_config(self):
|
||||
config = {'num_thresholds': self.num_thresholds,
|
||||
|
@ -877,15 +877,15 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAlmostEqual(1, self.evaluate(result))
|
||||
|
||||
def test_unweighted_high_sensitivity(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
|
||||
s_obj = metrics.SpecificityAtSensitivity(1.0)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
|
||||
y_true = constant_op.constant(label_values)
|
||||
self.evaluate(variables.variables_initializer(s_obj.variables))
|
||||
result = s_obj(y_true, y_pred)
|
||||
self.assertAlmostEqual(0.4, self.evaluate(result))
|
||||
self.assertAlmostEqual(0.2, self.evaluate(result))
|
||||
|
||||
def test_unweighted_low_sensitivity(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.4)
|
||||
@ -974,40 +974,42 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_unweighted_high_recall(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9]
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
# For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall.
|
||||
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
|
||||
y_true = constant_op.constant(label_values)
|
||||
self.evaluate(variables.variables_initializer(s_obj.variables))
|
||||
result = s_obj(y_true, y_pred)
|
||||
self.assertAlmostEqual(0.8, self.evaluate(result))
|
||||
# For 0.5 < decision threshold < 0.6.
|
||||
self.assertAlmostEqual(2.0/3, self.evaluate(result))
|
||||
|
||||
def test_unweighted_low_recall(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26]
|
||||
s_obj = metrics.PrecisionAtRecall(0.6)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
|
||||
y_true = constant_op.constant(label_values)
|
||||
self.evaluate(variables.variables_initializer(s_obj.variables))
|
||||
result = s_obj(y_true, y_pred)
|
||||
self.assertAlmostEqual(0.5, self.evaluate(result))
|
||||
# For 0.2 < decision threshold < 0.5.
|
||||
self.assertAlmostEqual(0.75, self.evaluate(result))
|
||||
|
||||
@parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32])
|
||||
def test_weighted(self, label_dtype):
|
||||
s_obj = metrics.PrecisionAtRecall(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
s_obj = metrics.PrecisionAtRecall(7.0/8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]
|
||||
|
||||
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
|
||||
y_true = math_ops.cast(label_values, dtype=label_dtype)
|
||||
weights = constant_op.constant(weight_values)
|
||||
self.evaluate(variables.variables_initializer(s_obj.variables))
|
||||
result = s_obj(y_true, y_pred, sample_weight=weights)
|
||||
self.assertAlmostEqual(2./3., self.evaluate(result))
|
||||
# For 0.0 < decision threshold < 0.2.
|
||||
self.assertAlmostEqual(0.7, self.evaluate(result))
|
||||
|
||||
def test_invalid_sensitivity(self):
|
||||
with self.assertRaisesRegexp(
|
||||
|
Loading…
Reference in New Issue
Block a user