Internal change
PiperOrigin-RevId: 299959121
Change-Id: If6e3e6355b4e2c71c965f6a13f78a5dca4f3ff32
This commit is contained in:
parent
97e1e8091a
commit
6b1d6f5343
@@ -1479,35 +1479,10 @@ class SensitivitySpecificityBase(Metric):
     K.batch_set_value(
         [(v, np.zeros((num_thresholds,))) for v in self.variables])
 
-  def _find_max_under_constraint(self, constrained, dependent, predicate):
-    """Returns the maximum of dependent_statistic that satisfies the constraint.
-
-    Args:
-      constrained: Over these values the constraint
-        is specified. A rank-1 tensor.
-      dependent: From these values the maximum that satisfies the
-        constraint is selected. Values in this tensor and in
-        `constrained` are linked by having the same threshold at each
-        position, hence this tensor must have the same shape.
-      predicate: A binary boolean functor to be applied to arguments
-        `constrained` and `self.value`, e.g. `tf.greater`.
-
-    Returns maximal dependent value; if no value satisfies the constraint, 0.0.
-    """
-    feasible = array_ops.where(predicate(constrained, self.value))
-    feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
-
-    def get_max():
-      return math_ops.reduce_max(array_ops.gather(dependent, feasible))
-
-    return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0)
-
 
 @keras_export('keras.metrics.SensitivityAtSpecificity')
 class SensitivityAtSpecificity(SensitivitySpecificityBase):
-  """Computes best sensitivity where specificity is >= specified value.
-
-  the sensitivity at a given specificity.
+  """Computes the sensitivity at a given specificity.
 
   `Sensitivity` measures the proportion of actual positives that are correctly
   identified as such (tp / (tp + fn)).
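
For context, the `_find_max_under_constraint` helper removed here selects, over all thresholds, the maximum of `dependent` restricted to positions where `predicate(constrained, self.value)` holds, and falls back to 0.0 when no threshold is feasible. A minimal NumPy sketch of that logic (a standalone re-implementation for illustration, not the Keras code):

    import numpy as np

    def find_max_under_constraint(constrained, dependent, predicate):
        # Positions (thresholds) where the constraint holds,
        # e.g. specificity >= target.
        feasible = np.where(predicate(constrained))[0]
        if feasible.size == 0:
            return 0.0  # No threshold satisfies the constraint.
        return float(dependent[feasible].max())

    # Per-threshold statistics for three thresholds (made-up values).
    specificities = np.array([0.2, 0.6, 0.9])
    sensitivities = np.array([0.9, 0.7, 0.3])
    print(find_max_under_constraint(
        specificities, sensitivities, lambda c: c >= 0.5))  # 0.7
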
@@ -1527,16 +1502,16 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
 
   Usage:
 
-  >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
+  >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1)
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
   >>> m.result().numpy()
   0.5
 
   >>> m.reset_states()
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
-  ...                    sample_weight=[1, 1, 2, 2, 1])
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
+  ...                    sample_weight=[1, 0, 0, 1])
   >>> m.result().numpy()
-  0.333333
+  1.0
 
   Usage with tf.keras API:
 
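
The restored doctest can be checked by hand: with `num_thresholds=1` the only decision threshold is 0.5, and a prediction counts as positive when it exceeds the threshold. A quick NumPy verification of the weighted case, mirroring the doctest data (illustrative only):

    import numpy as np

    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.0, 0.5, 0.3, 0.9])
    w = np.array([1, 0, 0, 1])

    pos = y_pred > 0.5
    tp = np.sum(w * (pos & (y_true == 1)))   # 1
    fn = np.sum(w * (~pos & (y_true == 1)))  # 0
    print(tp / (tp + fn))                    # 1.0, as in the doctest
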
@@ -1567,12 +1542,20 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
         specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
 
   def result(self):
+    # Calculate specificities at all the thresholds.
     specificities = math_ops.div_no_nan(
         self.true_negatives, self.true_negatives + self.false_positives)
-    sensitivities = math_ops.div_no_nan(
-        self.true_positives, self.true_positives + self.false_negatives)
-    return self._find_max_under_constraint(
-        specificities, sensitivities, math_ops.greater_equal)
+
+    # Find the index of the threshold where the specificity is closest to the
+    # given specificity.
+    min_index = math_ops.argmin(
+        math_ops.abs(specificities - self.value), axis=0)
+    min_index = math_ops.cast(min_index, dtypes.int32)
+
+    # Compute sensitivity at that index.
+    return math_ops.div_no_nan(
+        self.true_positives[min_index],
+        self.true_positives[min_index] + self.false_negatives[min_index])
 
   def get_config(self):
     config = {
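
The restored `result()` reports the dependent statistic at the threshold whose constrained statistic is closest to the requested value, rather than the best feasible one. In plain NumPy terms, with made-up per-threshold values:

    import numpy as np

    specificities = np.array([0.1, 0.45, 0.8])  # one entry per threshold
    sensitivities = np.array([0.95, 0.7, 0.4])
    target = 0.5

    min_index = np.argmin(np.abs(specificities - target))  # -> 1
    print(sensitivities[min_index])  # 0.7
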
@@ -1585,7 +1568,7 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
 
 @keras_export('keras.metrics.SpecificityAtSensitivity')
 class SpecificityAtSensitivity(SensitivitySpecificityBase):
-  """Computes best specificity where sensitivity is >= specified value.
+  """Computes the specificity at a given sensitivity.
 
   `Sensitivity` measures the proportion of actual positives that are correctly
   identified as such (tp / (tp + fn)).
@@ -1605,16 +1588,16 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
 
   Usage:
 
-  >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
+  >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1)
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
   >>> m.result().numpy()
-  0.66666667
+  1.0
 
   >>> m.reset_states()
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
-  ...                    sample_weight=[1, 1, 2, 2, 2])
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
+  ...                    sample_weight=[1, 0, 0, 1])
   >>> m.result().numpy()
-  0.5
+  1.0
 
   Usage with tf.keras API:
 
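
This one is also checkable by hand: at the single 0.5 threshold no negative example scores above the threshold, so there are no false positives and specificity is 1.0 regardless of the sample weights used. A short NumPy check mirroring the doctest data:

    import numpy as np

    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.0, 0.5, 0.3, 0.9])

    pos = y_pred > 0.5                 # positive iff score exceeds 0.5
    tn = np.sum(~pos & (y_true == 0))  # 2
    fp = np.sum(pos & (y_true == 0))   # 0
    print(tn / (tn + fp))              # 1.0, as in the doctest
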
@@ -1645,12 +1628,20 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
         sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
 
   def result(self):
+    # Calculate sensitivities at all the thresholds.
     sensitivities = math_ops.div_no_nan(
         self.true_positives, self.true_positives + self.false_negatives)
-    specificities = math_ops.div_no_nan(
-        self.true_negatives, self.true_negatives + self.false_positives)
-    return self._find_max_under_constraint(
-        sensitivities, specificities, math_ops.greater_equal)
+
+    # Find the index of the threshold where the sensitivity is closest to the
+    # requested value.
+    min_index = math_ops.argmin(
+        math_ops.abs(sensitivities - self.value), axis=0)
+    min_index = math_ops.cast(min_index, dtypes.int32)
+
+    # Compute specificity at that index.
+    return math_ops.div_no_nan(
+        self.true_negatives[min_index],
+        self.true_negatives[min_index] + self.false_positives[min_index])
 
   def get_config(self):
     config = {
@@ -1663,7 +1654,7 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
 
 @keras_export('keras.metrics.PrecisionAtRecall')
 class PrecisionAtRecall(SensitivitySpecificityBase):
-  """Computes best precision where recall is >= specified value.
+  """Computes the precision at a given recall.
 
   This metric creates four local variables, `true_positives`, `true_negatives`,
   `false_positives` and `false_negatives` that are used to compute the
@@ -1675,16 +1666,16 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
 
   Usage:
 
-  >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
+  >>> m = tf.keras.metrics.PrecisionAtRecall(0.8, num_thresholds=1)
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
   >>> m.result().numpy()
-  0.5
+  1.0
 
   >>> m.reset_states()
-  >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
-  ...                    sample_weight=[2, 2, 2, 1, 1])
+  >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
+  ...                    sample_weight=[1, 0, 0, 1])
   >>> m.result().numpy()
-  0.33333333
+  1.0
 
   Usage with tf.keras API:
 
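
The restored precision numbers follow the same arithmetic: at the 0.5 threshold only the 0.9 score counts as positive, so precision is 1/1. Illustrative NumPy check:

    import numpy as np

    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.0, 0.5, 0.3, 0.9])

    pos = y_pred > 0.5
    tp = np.sum(pos & (y_true == 1))  # 1
    fp = np.sum(pos & (y_true == 0))  # 0
    print(tp / (tp + fp))             # 1.0, as in the doctest
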
@@ -1718,12 +1709,20 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
         dtype=dtype)
 
   def result(self):
+    # Calculate recall at all the thresholds.
     recalls = math_ops.div_no_nan(
         self.true_positives, self.true_positives + self.false_negatives)
-    precisions = math_ops.div_no_nan(
-        self.true_positives, self.true_positives + self.false_positives)
-    return self._find_max_under_constraint(
-        recalls, precisions, math_ops.greater_equal)
+
+    # Find the index of the threshold where the recall is closest to the
+    # requested value.
+    min_index = math_ops.argmin(
+        math_ops.abs(recalls - self.value), axis=0)
+    min_index = math_ops.cast(min_index, dtypes.int32)
+
+    # Compute precision at that index.
+    return math_ops.div_no_nan(
+        self.true_positives[min_index],
+        self.true_positives[min_index] + self.false_positives[min_index])
 
   def get_config(self):
     config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
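
Worth noting where the two `result()` styles diverge: when no threshold reaches the requested recall, the restored argmin form still reports the precision at the nearest threshold, while the removed constrained-max form returns 0.0. A sketch with made-up per-threshold values:

    import numpy as np

    recalls = np.array([0.3, 0.2, 0.1])     # best achievable recall is 0.3
    precisions = np.array([0.6, 0.7, 0.9])
    target = 0.8                            # not achievable here

    closest = precisions[np.argmin(np.abs(recalls - target))]    # 0.6
    feasible = np.where(recalls >= target)[0]
    best = precisions[feasible].max() if feasible.size else 0.0  # 0.0
    print(closest, best)
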
@@ -1733,7 +1732,7 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
 
 @keras_export('keras.metrics.RecallAtPrecision')
 class RecallAtPrecision(SensitivitySpecificityBase):
-  """Computes best recall where precision is >= specified value.
+  """Computes the maximally achievable recall at a required precision.
 
   For a given score-label-distribution the required precision might not
   be achievable; in this case 0.0 is returned as recall.
@@ -1748,7 +1747,7 @@ class RecallAtPrecision(SensitivitySpecificityBase):
 
   Usage:
 
-  >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
+  >>> m = tf.keras.metrics.RecallAtPrecision(0.8, num_thresholds=1)
   >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
   >>> m.result().numpy()
   0.5
@@ -1791,12 +1790,21 @@ class RecallAtPrecision(SensitivitySpecificityBase):
         dtype=dtype)
 
   def result(self):
+    # Calculate precision and recall at all the thresholds.
+    # All recalls are computed, because they are not a monotonic function of
+    # precision, and we want to search for the highest feasible recall.
     precisions = math_ops.div_no_nan(
         self.true_positives, self.true_positives + self.false_positives)
     recalls = math_ops.div_no_nan(
         self.true_positives, self.true_positives + self.false_negatives)
-    return self._find_max_under_constraint(
-        precisions, recalls, math_ops.greater_equal)
+
+    # Find best recall where the precision is as good as required.
+    feasible = array_ops.where(math_ops.greater_equal(precisions, self.value))
+    feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
+    best_recall = control_flow_ops.cond(
+        feasible_exists,
+        lambda: math_ops.reduce_max(array_ops.gather(recalls, feasible)),
+        lambda: 0.0)
+    return best_recall
 
   def get_config(self):
     config = {'num_thresholds': self.num_thresholds,
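
The restored body inlines the same feasibility search that the shared helper performed: recall is not a monotonic function of precision, so every threshold must be scanned for the best feasible recall. An equivalent NumPy sketch with illustrative values:

    import numpy as np

    precisions = np.array([0.5, 0.9, 0.7, 1.0])  # per-threshold statistics
    recalls = np.array([1.0, 0.8, 0.6, 0.2])
    required = 0.8

    feasible = np.where(precisions >= required)[0]  # [1, 3]
    best_recall = recalls[feasible].max() if feasible.size else 0.0
    print(best_recall)  # 0.8
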
@@ -877,15 +877,15 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
     self.assertAlmostEqual(1, self.evaluate(result))
 
   def test_unweighted_high_sensitivity(self):
-    s_obj = metrics.SpecificityAtSensitivity(1.0)
-    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+    s_obj = metrics.SpecificityAtSensitivity(0.8)
+    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
     label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
 
     y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
     y_true = constant_op.constant(label_values)
     self.evaluate(variables.variables_initializer(s_obj.variables))
     result = s_obj(y_true, y_pred)
-    self.assertAlmostEqual(0.2, self.evaluate(result))
+    self.assertAlmostEqual(0.4, self.evaluate(result))
 
   def test_unweighted_low_sensitivity(self):
     s_obj = metrics.SpecificityAtSensitivity(0.4)
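
The restored expectation of 0.4 can be reproduced by hand: any decision threshold just above 0.1 yields sensitivity exactly 0.8 (four of the five positives score above it), and the specificity at such a threshold is 2/5. A NumPy check with one representative threshold (the metric's default threshold grid includes thresholds in this range):

    import numpy as np

    pred = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9])
    label = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    t = 0.15  # representative threshold in (0.1, 0.45)

    pos = pred > t
    sensitivity = np.sum(pos & (label == 1)) / 5   # 4/5 = 0.8
    specificity = np.sum(~pos & (label == 0)) / 5  # 2/5 = 0.4
    print(sensitivity, specificity)
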
@@ -974,42 +974,40 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
 
   def test_unweighted_high_recall(self):
     s_obj = metrics.PrecisionAtRecall(0.8)
-    pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
+    pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9]
     label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
 
+    # For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall.
     y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
     y_true = constant_op.constant(label_values)
     self.evaluate(variables.variables_initializer(s_obj.variables))
     result = s_obj(y_true, y_pred)
-    # For 0.5 < decision threshold < 0.6.
-    self.assertAlmostEqual(2.0/3, self.evaluate(result))
+    self.assertAlmostEqual(0.8, self.evaluate(result))
 
   def test_unweighted_low_recall(self):
-    s_obj = metrics.PrecisionAtRecall(0.6)
-    pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
+    s_obj = metrics.PrecisionAtRecall(0.4)
+    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26]
     label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
 
     y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
     y_true = constant_op.constant(label_values)
     self.evaluate(variables.variables_initializer(s_obj.variables))
     result = s_obj(y_true, y_pred)
-    # For 0.2 < decision threshold < 0.5.
-    self.assertAlmostEqual(0.75, self.evaluate(result))
+    self.assertAlmostEqual(0.5, self.evaluate(result))
 
   @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32])
   def test_weighted(self, label_dtype):
-    s_obj = metrics.PrecisionAtRecall(7.0/8)
-    pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
+    s_obj = metrics.PrecisionAtRecall(0.4)
+    pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
     label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
-    weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]
+    weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2]
 
     y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
     y_true = math_ops.cast(label_values, dtype=label_dtype)
     weights = constant_op.constant(weight_values)
     self.evaluate(variables.variables_initializer(s_obj.variables))
     result = s_obj(y_true, y_pred, sample_weight=weights)
-    # For 0.0 < decision threshold < 0.2.
-    self.assertAlmostEqual(0.7, self.evaluate(result))
+    self.assertAlmostEqual(2./3., self.evaluate(result))
 
   def test_invalid_sensitivity(self):
     with self.assertRaisesRegexp(
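
The restored weighted expectation of 2/3 also checks out by hand: for a decision threshold between 0.25 and 0.26, the weighted recall is 4/8 = 0.5 (the achievable value closest to the requested 0.4), and the weighted precision there is 4/6. Illustrative NumPy check:

    import numpy as np

    pred = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26])
    label = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    w = np.array([2, 2, 1, 1, 1, 1, 1, 2, 2, 2])
    t = 0.255  # representative threshold in (0.25, 0.26)

    pos = pred > t
    tp = np.sum(w * (pos & (label == 1)))  # 4
    fp = np.sum(w * (pos & (label == 0)))  # 2
    print(tp / (tp + fp))                  # 0.666..., i.e. 2/3
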