Internal change

PiperOrigin-RevId: 299959121
Change-Id: If6e3e6355b4e2c71c965f6a13f78a5dca4f3ff32
This commit is contained in:
A. Unique TensorFlower 2020-03-09 16:06:35 -07:00 committed by TensorFlower Gardener
parent 97e1e8091a
commit 6b1d6f5343
2 changed files with 82 additions and 76 deletions

View File

@ -1479,35 +1479,10 @@ class SensitivitySpecificityBase(Metric):
K.batch_set_value( K.batch_set_value(
[(v, np.zeros((num_thresholds,))) for v in self.variables]) [(v, np.zeros((num_thresholds,))) for v in self.variables])
def _find_max_under_constraint(self, constrained, dependent, predicate):
"""Returns the maximum of dependent_statistic that satisfies the constraint.
Args:
constrained: Over these values the constraint
is specified. A rank-1 tensor.
dependent: From these values the maximum that satiesfies the
constraint is selected. Values in this tensor and in
`constrained` are linked by having the same threshold at each
position, hence this tensor must have the same shape.
predicate: A binary boolean functor to be applied to arguments
`constrained` and `self.value`, e.g. `tf.greater`.
Returns maximal dependent value, if no value satiesfies the constraint 0.0.
"""
feasible = array_ops.where(predicate(constrained, self.value))
feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
def get_max():
return math_ops.reduce_max(array_ops.gather(dependent, feasible))
return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0)
@keras_export('keras.metrics.SensitivityAtSpecificity') @keras_export('keras.metrics.SensitivityAtSpecificity')
class SensitivityAtSpecificity(SensitivitySpecificityBase): class SensitivityAtSpecificity(SensitivitySpecificityBase):
"""Computes best sensitivity where specificity is >= specified value. """Computes the sensitivity at a given specificity.
the sensitivity at a given specificity.
`Sensitivity` measures the proportion of actual positives that are correctly `Sensitivity` measures the proportion of actual positives that are correctly
identified as such (tp / (tp + fn)). identified as such (tp / (tp + fn)).
@ -1527,16 +1502,16 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
Usage: Usage:
>>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1)
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result().numpy() >>> m.result().numpy()
0.5 0.5
>>> m.reset_states() >>> m.reset_states()
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[1, 1, 2, 2, 1]) ... sample_weight=[1, 0, 0, 1])
>>> m.result().numpy() >>> m.result().numpy()
0.333333 1.0
Usage with tf.keras API: Usage with tf.keras API:
@ -1567,12 +1542,20 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
def result(self): def result(self):
# Calculate specificities at all the thresholds.
specificities = math_ops.div_no_nan( specificities = math_ops.div_no_nan(
self.true_negatives, self.true_negatives + self.false_positives) self.true_negatives, self.true_negatives + self.false_positives)
sensitivities = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_negatives) # Find the index of the threshold where the specificity is closest to the
return self._find_max_under_constraint( # given specificity.
specificities, sensitivities, math_ops.greater_equal) min_index = math_ops.argmin(
math_ops.abs(specificities - self.value), axis=0)
min_index = math_ops.cast(min_index, dtypes.int32)
# Compute sensitivity at that index.
return math_ops.div_no_nan(
self.true_positives[min_index],
self.true_positives[min_index] + self.false_negatives[min_index])
def get_config(self): def get_config(self):
config = { config = {
@ -1585,7 +1568,7 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
@keras_export('keras.metrics.SpecificityAtSensitivity') @keras_export('keras.metrics.SpecificityAtSensitivity')
class SpecificityAtSensitivity(SensitivitySpecificityBase): class SpecificityAtSensitivity(SensitivitySpecificityBase):
"""Computes best specificity where sensitivity is >= specified value. """Computes the specificity at a given sensitivity.
`Sensitivity` measures the proportion of actual positives that are correctly `Sensitivity` measures the proportion of actual positives that are correctly
identified as such (tp / (tp + fn)). identified as such (tp / (tp + fn)).
@ -1605,16 +1588,16 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
Usage: Usage:
>>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1)
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result().numpy() >>> m.result().numpy()
0.66666667 1.0
>>> m.reset_states() >>> m.reset_states()
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[1, 1, 2, 2, 2]) ... sample_weight=[1, 0, 0, 1])
>>> m.result().numpy() >>> m.result().numpy()
0.5 1.0
Usage with tf.keras API: Usage with tf.keras API:
@ -1645,12 +1628,20 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype) sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
def result(self): def result(self):
# Calculate sensitivities at all the thresholds.
sensitivities = math_ops.div_no_nan( sensitivities = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_negatives) self.true_positives, self.true_positives + self.false_negatives)
specificities = math_ops.div_no_nan(
self.true_negatives, self.true_negatives + self.false_positives) # Find the index of the threshold where the sensitivity is closest to the
return self._find_max_under_constraint( # requested value.
sensitivities, specificities, math_ops.greater_equal) min_index = math_ops.argmin(
math_ops.abs(sensitivities - self.value), axis=0)
min_index = math_ops.cast(min_index, dtypes.int32)
# Compute specificity at that index.
return math_ops.div_no_nan(
self.true_negatives[min_index],
self.true_negatives[min_index] + self.false_positives[min_index])
def get_config(self): def get_config(self):
config = { config = {
@ -1663,7 +1654,7 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
@keras_export('keras.metrics.PrecisionAtRecall') @keras_export('keras.metrics.PrecisionAtRecall')
class PrecisionAtRecall(SensitivitySpecificityBase): class PrecisionAtRecall(SensitivitySpecificityBase):
"""Computes best precision where recall is >= specified value. """Computes the precision at a given recall.
This metric creates four local variables, `true_positives`, `true_negatives`, This metric creates four local variables, `true_positives`, `true_negatives`,
`false_positives` and `false_negatives` that are used to compute the `false_positives` and `false_negatives` that are used to compute the
@ -1675,16 +1666,16 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
Usage: Usage:
>>> m = tf.keras.metrics.PrecisionAtRecall(0.5) >>> m = tf.keras.metrics.PrecisionAtRecall(0.8, num_thresholds=1)
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result().numpy() >>> m.result().numpy()
0.5 1.0
>>> m.reset_states() >>> m.reset_states()
>>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[2, 2, 2, 1, 1]) ... sample_weight=[1, 0, 0, 1])
>>> m.result().numpy() >>> m.result().numpy()
0.33333333 1.0
Usage with tf.keras API: Usage with tf.keras API:
@ -1718,12 +1709,20 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
dtype=dtype) dtype=dtype)
def result(self): def result(self):
# Calculate recall at all the thresholds.
recalls = math_ops.div_no_nan( recalls = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_negatives) self.true_positives, self.true_positives + self.false_negatives)
precisions = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_positives) # Find the index of the threshold where the recall is closest to the
return self._find_max_under_constraint( # requested value.
recalls, precisions, math_ops.greater_equal) min_index = math_ops.argmin(
math_ops.abs(recalls - self.value), axis=0)
min_index = math_ops.cast(min_index, dtypes.int32)
# Compute precision at that index.
return math_ops.div_no_nan(
self.true_positives[min_index],
self.true_positives[min_index] + self.false_positives[min_index])
def get_config(self): def get_config(self):
config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
@ -1733,7 +1732,7 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
@keras_export('keras.metrics.RecallAtPrecision') @keras_export('keras.metrics.RecallAtPrecision')
class RecallAtPrecision(SensitivitySpecificityBase): class RecallAtPrecision(SensitivitySpecificityBase):
"""Computes best recall where precision is >= specified value. """Computes the maximally achievable recall at a required precision.
For a given score-label-distribution the required precision might not For a given score-label-distribution the required precision might not
be achievable, in this case 0.0 is returned as recall. be achievable, in this case 0.0 is returned as recall.
@ -1748,7 +1747,7 @@ class RecallAtPrecision(SensitivitySpecificityBase):
Usage: Usage:
>>> m = tf.keras.metrics.RecallAtPrecision(0.8) >>> m = tf.keras.metrics.RecallAtPrecision(0.8, num_thresholds=1)
>>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result().numpy() >>> m.result().numpy()
0.5 0.5
@ -1791,12 +1790,21 @@ class RecallAtPrecision(SensitivitySpecificityBase):
dtype=dtype) dtype=dtype)
def result(self): def result(self):
# Calculate precision and recall at all the thresholds.
# All recalls are computed, because they are not a monotoneous function of
# precision and we want to search for the highest feasible recall.
precisions = math_ops.div_no_nan( precisions = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_positives) self.true_positives, self.true_positives + self.false_positives)
recalls = math_ops.div_no_nan( recalls = math_ops.div_no_nan(
self.true_positives, self.true_positives + self.false_negatives) self.true_positives, self.true_positives + self.false_negatives)
return self._find_max_under_constraint( # Find best recall where the precision is as good as required.
precisions, recalls, math_ops.greater_equal) feasible = array_ops.where(math_ops.greater_equal(precisions, self.value))
feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
best_recall = control_flow_ops.cond(
feasible_exists,
lambda: math_ops.reduce_max(array_ops.gather(recalls, feasible)),
lambda: 0.0)
return best_recall
def get_config(self): def get_config(self):
config = {'num_thresholds': self.num_thresholds, config = {'num_thresholds': self.num_thresholds,

View File

@ -877,15 +877,15 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(1, self.evaluate(result)) self.assertAlmostEqual(1, self.evaluate(result))
def test_unweighted_high_sensitivity(self): def test_unweighted_high_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(1.0) s_obj = metrics.SpecificityAtSensitivity(0.8)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
y_true = constant_op.constant(label_values) y_true = constant_op.constant(label_values)
self.evaluate(variables.variables_initializer(s_obj.variables)) self.evaluate(variables.variables_initializer(s_obj.variables))
result = s_obj(y_true, y_pred) result = s_obj(y_true, y_pred)
self.assertAlmostEqual(0.2, self.evaluate(result)) self.assertAlmostEqual(0.4, self.evaluate(result))
def test_unweighted_low_sensitivity(self): def test_unweighted_low_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(0.4) s_obj = metrics.SpecificityAtSensitivity(0.4)
@ -974,42 +974,40 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase):
def test_unweighted_high_recall(self): def test_unweighted_high_recall(self):
s_obj = metrics.PrecisionAtRecall(0.8) s_obj = metrics.PrecisionAtRecall(0.8)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
# For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall.
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
y_true = constant_op.constant(label_values) y_true = constant_op.constant(label_values)
self.evaluate(variables.variables_initializer(s_obj.variables)) self.evaluate(variables.variables_initializer(s_obj.variables))
result = s_obj(y_true, y_pred) result = s_obj(y_true, y_pred)
# For 0.5 < decision threshold < 0.6. self.assertAlmostEqual(0.8, self.evaluate(result))
self.assertAlmostEqual(2.0/3, self.evaluate(result))
def test_unweighted_low_recall(self): def test_unweighted_low_recall(self):
s_obj = metrics.PrecisionAtRecall(0.6) s_obj = metrics.PrecisionAtRecall(0.4)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
y_true = constant_op.constant(label_values) y_true = constant_op.constant(label_values)
self.evaluate(variables.variables_initializer(s_obj.variables)) self.evaluate(variables.variables_initializer(s_obj.variables))
result = s_obj(y_true, y_pred) result = s_obj(y_true, y_pred)
# For 0.2 < decision threshold < 0.5. self.assertAlmostEqual(0.5, self.evaluate(result))
self.assertAlmostEqual(0.75, self.evaluate(result))
@parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32]) @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32])
def test_weighted(self, label_dtype): def test_weighted(self, label_dtype):
s_obj = metrics.PrecisionAtRecall(7.0/8) s_obj = metrics.PrecisionAtRecall(0.4)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2] weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2]
y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
y_true = math_ops.cast(label_values, dtype=label_dtype) y_true = math_ops.cast(label_values, dtype=label_dtype)
weights = constant_op.constant(weight_values) weights = constant_op.constant(weight_values)
self.evaluate(variables.variables_initializer(s_obj.variables)) self.evaluate(variables.variables_initializer(s_obj.variables))
result = s_obj(y_true, y_pred, sample_weight=weights) result = s_obj(y_true, y_pred, sample_weight=weights)
# For 0.0 < decision threshold < 0.2. self.assertAlmostEqual(2./3., self.evaluate(result))
self.assertAlmostEqual(0.7, self.evaluate(result))
def test_invalid_sensitivity(self): def test_invalid_sensitivity(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(