diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index ea98d53accd..81abaabbe32 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -1479,35 +1479,10 @@ class SensitivitySpecificityBase(Metric): K.batch_set_value( [(v, np.zeros((num_thresholds,))) for v in self.variables]) - def _find_max_under_constraint(self, constrained, dependent, predicate): - """Returns the maximum of dependent_statistic that satisfies the constraint. - - Args: - constrained: Over these values the constraint - is specified. A rank-1 tensor. - dependent: From these values the maximum that satiesfies the - constraint is selected. Values in this tensor and in - `constrained` are linked by having the same threshold at each - position, hence this tensor must have the same shape. - predicate: A binary boolean functor to be applied to arguments - `constrained` and `self.value`, e.g. `tf.greater`. - - Returns maximal dependent value, if no value satiesfies the constraint 0.0. - """ - feasible = array_ops.where(predicate(constrained, self.value)) - feasible_exists = math_ops.greater(array_ops.size(feasible), 0) - - def get_max(): - return math_ops.reduce_max(array_ops.gather(dependent, feasible)) - - return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0) - @keras_export('keras.metrics.SensitivityAtSpecificity') class SensitivityAtSpecificity(SensitivitySpecificityBase): - """Computes best sensitivity where specificity is >= specified value. - - the sensitivity at a given specificity. + """Computes the sensitivity at a given specificity. `Sensitivity` measures the proportion of actual positives that are correctly identified as such (tp / (tp + fn)). @@ -1527,16 +1502,16 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() 0.5 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[1, 1, 2, 2, 1]) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() - 0.333333 + 1.0 Usage with tf.keras API: @@ -1567,12 +1542,20 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) def result(self): + # Calculate specificities at all the thresholds. specificities = math_ops.div_no_nan( self.true_negatives, self.true_negatives + self.false_positives) - sensitivities = math_ops.div_no_nan( - self.true_positives, self.true_positives + self.false_negatives) - return self._find_max_under_constraint( - specificities, sensitivities, math_ops.greater_equal) + + # Find the index of the threshold where the specificity is closest to the + # given specificity. + min_index = math_ops.argmin( + math_ops.abs(specificities - self.value), axis=0) + min_index = math_ops.cast(min_index, dtypes.int32) + + # Compute sensitivity at that index. + return math_ops.div_no_nan( + self.true_positives[min_index], + self.true_positives[min_index] + self.false_negatives[min_index]) def get_config(self): config = { @@ -1585,7 +1568,7 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase): @keras_export('keras.metrics.SpecificityAtSensitivity') class SpecificityAtSensitivity(SensitivitySpecificityBase): - """Computes best specificity where sensitivity is >= specified value. + """Computes the specificity at a given sensitivity. `Sensitivity` measures the proportion of actual positives that are correctly identified as such (tp / (tp + fn)). @@ -1605,16 +1588,16 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() - 0.66666667 + 1.0 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[1, 1, 2, 2, 2]) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() - 0.5 + 1.0 Usage with tf.keras API: @@ -1645,12 +1628,20 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype) def result(self): + # Calculate sensitivities at all the thresholds. sensitivities = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - specificities = math_ops.div_no_nan( - self.true_negatives, self.true_negatives + self.false_positives) - return self._find_max_under_constraint( - sensitivities, specificities, math_ops.greater_equal) + + # Find the index of the threshold where the sensitivity is closest to the + # requested value. + min_index = math_ops.argmin( + math_ops.abs(sensitivities - self.value), axis=0) + min_index = math_ops.cast(min_index, dtypes.int32) + + # Compute specificity at that index. + return math_ops.div_no_nan( + self.true_negatives[min_index], + self.true_negatives[min_index] + self.false_positives[min_index]) def get_config(self): config = { @@ -1663,7 +1654,7 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase): @keras_export('keras.metrics.PrecisionAtRecall') class PrecisionAtRecall(SensitivitySpecificityBase): - """Computes best precision where recall is >= specified value. + """Computes the precision at a given recall. This metric creates four local variables, `true_positives`, `true_negatives`, `false_positives` and `false_negatives` that are used to compute the @@ -1675,16 +1666,16 @@ class PrecisionAtRecall(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m = tf.keras.metrics.PrecisionAtRecall(0.8, num_thresholds=1) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() - 0.5 + 1.0 >>> m.reset_states() - >>> _ = m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], - ... sample_weight=[2, 2, 2, 1, 1]) + >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) >>> m.result().numpy() - 0.33333333 + 1.0 Usage with tf.keras API: @@ -1718,12 +1709,20 @@ class PrecisionAtRecall(SensitivitySpecificityBase): dtype=dtype) def result(self): + # Calculate recall at all the thresholds. recalls = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - precisions = math_ops.div_no_nan( - self.true_positives, self.true_positives + self.false_positives) - return self._find_max_under_constraint( - recalls, precisions, math_ops.greater_equal) + + # Find the index of the threshold where the recall is closest to the + # requested value. + min_index = math_ops.argmin( + math_ops.abs(recalls - self.value), axis=0) + min_index = math_ops.cast(min_index, dtypes.int32) + + # Compute precision at that index. + return math_ops.div_no_nan( + self.true_positives[min_index], + self.true_positives[min_index] + self.false_positives[min_index]) def get_config(self): config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} @@ -1733,7 +1732,7 @@ class PrecisionAtRecall(SensitivitySpecificityBase): @keras_export('keras.metrics.RecallAtPrecision') class RecallAtPrecision(SensitivitySpecificityBase): - """Computes best recall where precision is >= specified value. + """Computes the maximally achievable recall at a required precision. For a given score-label-distribution the required precision might not be achievable, in this case 0.0 is returned as recall. @@ -1748,7 +1747,7 @@ class RecallAtPrecision(SensitivitySpecificityBase): Usage: - >>> m = tf.keras.metrics.RecallAtPrecision(0.8) + >>> m = tf.keras.metrics.RecallAtPrecision(0.8, num_thresholds=1) >>> _ = m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) >>> m.result().numpy() 0.5 @@ -1791,12 +1790,21 @@ class RecallAtPrecision(SensitivitySpecificityBase): dtype=dtype) def result(self): + # Calculate precision and recall at all the thresholds. + # All recalls are computed, because they are not a monotoneous function of + # precision and we want to search for the highest feasible recall. precisions = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_positives) recalls = math_ops.div_no_nan( self.true_positives, self.true_positives + self.false_negatives) - return self._find_max_under_constraint( - precisions, recalls, math_ops.greater_equal) + # Find best recall where the precision is as good as required. + feasible = array_ops.where(math_ops.greater_equal(precisions, self.value)) + feasible_exists = math_ops.greater(array_ops.size(feasible), 0) + best_recall = control_flow_ops.cond( + feasible_exists, + lambda: math_ops.reduce_max(array_ops.gather(recalls, feasible)), + lambda: 0.0) + return best_recall def get_config(self): config = {'num_thresholds': self.num_thresholds, diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py index 186c3f0328f..2ea6282cb27 100644 --- a/tensorflow/python/keras/metrics_confusion_matrix_test.py +++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py @@ -877,15 +877,15 @@ class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase): self.assertAlmostEqual(1, self.evaluate(result)) def test_unweighted_high_sensitivity(self): - s_obj = metrics.SpecificityAtSensitivity(1.0) - pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + s_obj = metrics.SpecificityAtSensitivity(0.8) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - self.assertAlmostEqual(0.2, self.evaluate(result)) + self.assertAlmostEqual(0.4, self.evaluate(result)) def test_unweighted_low_sensitivity(self): s_obj = metrics.SpecificityAtSensitivity(0.4) @@ -974,42 +974,40 @@ class PrecisionAtRecallTest(test.TestCase, parameterized.TestCase): def test_unweighted_high_recall(self): s_obj = metrics.PrecisionAtRecall(0.8) - pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + pred_values = [0.0, 0.1, 0.2, 0.3, 0.5, 0.4, 0.5, 0.6, 0.8, 0.9] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + # For a score between 0.4 and 0.5, we expect 0.8 precision, 0.8 recall. y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - # For 0.5 < decision threshold < 0.6. - self.assertAlmostEqual(2.0/3, self.evaluate(result)) + self.assertAlmostEqual(0.8, self.evaluate(result)) def test_unweighted_low_recall(self): - s_obj = metrics.PrecisionAtRecall(0.6) - pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + s_obj = metrics.PrecisionAtRecall(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.15, 0.25, 0.26, 0.26] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = constant_op.constant(label_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred) - # For 0.2 < decision threshold < 0.5. - self.assertAlmostEqual(0.75, self.evaluate(result)) + self.assertAlmostEqual(0.5, self.evaluate(result)) @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32]) def test_weighted(self, label_dtype): - s_obj = metrics.PrecisionAtRecall(7.0/8) - pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + s_obj = metrics.PrecisionAtRecall(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] - weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2] + weight_values = [2, 2, 1, 1, 1, 1, 1, 2, 2, 2] y_pred = constant_op.constant(pred_values, dtype=dtypes.float32) y_true = math_ops.cast(label_values, dtype=label_dtype) weights = constant_op.constant(weight_values) self.evaluate(variables.variables_initializer(s_obj.variables)) result = s_obj(y_true, y_pred, sample_weight=weights) - # For 0.0 < decision threshold < 0.2. - self.assertAlmostEqual(0.7, self.evaluate(result)) + self.assertAlmostEqual(2./3., self.evaluate(result)) def test_invalid_sensitivity(self): with self.assertRaisesRegexp(