Add support for AUC and PR-AUC for binary classification model outputting logits.
PiperOrigin-RevId: 345437841 Change-Id: If00e810b0b6037c9c317c5ee143ad44084771a62
This commit is contained in:
parent
dbb41a3f72
commit
4e0f09d355
@ -22,6 +22,7 @@
|
||||
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
|
||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||
* `tf.keras.metrics.AUC` now support logit predictions.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import math
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
@ -37,6 +38,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.activations import sigmoid
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
@ -1844,7 +1846,17 @@ class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
|
||||
@keras_export('keras.metrics.AUC')
|
||||
class AUC(Metric):
|
||||
"""Computes the approximate AUC (Area under the curve) via a Riemann sum.
|
||||
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
|
||||
|
||||
The AUC (Area under the curve) of the ROC (Receiver operating
|
||||
characteristic; default) or PR (Precision Recall) curves are quality measures
|
||||
of binary classifiers. Unlike the accuracy, and like cross-entropy
|
||||
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
|
||||
|
||||
This classes approximates AUCs using a Riemann sum: During the metric
|
||||
accumulation phrase, predictions are accumulated within predefined buckets
|
||||
by value. The AUC is then computed by interpolating per-bucket averages. These
|
||||
buckets define the evaluated operational points.
|
||||
|
||||
This metric creates four local variables, `true_positives`, `true_negatives`,
|
||||
`false_positives` and `false_negatives` that are used to compute the AUC.
|
||||
@ -1862,11 +1874,11 @@ class AUC(Metric):
|
||||
dramatically depending on `num_thresholds`. The `thresholds` parameter can be
|
||||
used to manually specify thresholds which split the predictions more evenly.
|
||||
|
||||
For best results, `predictions` should be distributed approximately uniformly
|
||||
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
|
||||
approximation may be poor if this is not the case. Setting `summation_method`
|
||||
to 'minoring' or 'majoring' can help quantify the error in the approximation
|
||||
by providing lower or upper bound estimate of the AUC.
|
||||
For a best approximation of the real AUC, `predictions` should be distributed
|
||||
approximately uniformly in the range [0, 1] (if `from_logits=False`). The
|
||||
quality of the AUC approximation may be poor if this is not the case. Setting
|
||||
`summation_method` to 'minoring' or 'majoring' can help quantify the error in
|
||||
the approximation by providing lower or upper bound estimate of the AUC.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
@ -1912,6 +1924,10 @@ class AUC(Metric):
|
||||
label, whereas label_weights depends only on the index of that label
|
||||
before flattening; therefore `label_weights` should not be used for
|
||||
multi-class data.
|
||||
from_logits: boolean indicating whether the predictions (`y_pred` in
|
||||
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
|
||||
when using a keras loss, the `from_logits` constructor argument of the
|
||||
loss should match the AUC `from_logits` constructor argument.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
@ -1933,7 +1949,15 @@ class AUC(Metric):
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
|
||||
# Reports the AUC of a model outputing a probability.
|
||||
model.compile(optimizer='sgd',
|
||||
loss=tf.keras.losses.BinaryCrossentropy(),
|
||||
metrics=[tf.keras.metrics.AUC()])
|
||||
|
||||
# Reports the AUC of a model outputing a logit.
|
||||
model.compile(optimizer='sgd',
|
||||
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
|
||||
metrics=[tf.keras.metrics.AUC(from_logits=True)])
|
||||
```
|
||||
"""
|
||||
|
||||
@ -1946,7 +1970,8 @@ class AUC(Metric):
|
||||
thresholds=None,
|
||||
multi_label=False,
|
||||
num_labels=None,
|
||||
label_weights=None):
|
||||
label_weights=None,
|
||||
from_logits=False):
|
||||
# Validate configurations.
|
||||
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
||||
metrics_utils.AUCCurve):
|
||||
@ -2006,6 +2031,8 @@ class AUC(Metric):
|
||||
else:
|
||||
self.label_weights = None
|
||||
|
||||
self._from_logits = from_logits
|
||||
|
||||
self._built = False
|
||||
if self.multi_label:
|
||||
if num_labels:
|
||||
@ -2105,6 +2132,10 @@ class AUC(Metric):
|
||||
# multi_label is False. Otherwise the averaging of individual label AUCs is
|
||||
# handled in AUC.result
|
||||
label_weights = None if self.multi_label else self.label_weights
|
||||
|
||||
if self._from_logits:
|
||||
y_pred = sigmoid(y_pred)
|
||||
|
||||
with ops.control_dependencies(deps):
|
||||
return metrics_utils.update_confusion_matrix_variables(
|
||||
{
|
||||
|
@ -1142,6 +1142,8 @@ class AUCTest(test.TestCase, parameterized.TestCase):
|
||||
def setup(self):
|
||||
self.num_thresholds = 3
|
||||
self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32)
|
||||
epsilon = 1e-12
|
||||
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||
self.y_true = constant_op.constant([0, 0, 1, 1])
|
||||
self.sample_weight = [1, 2, 3, 4]
|
||||
|
||||
@ -1264,6 +1266,20 @@ class AUCTest(test.TestCase, parameterized.TestCase):
|
||||
expected_result = (0.75 * 1 + 0.25 * 0)
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_from_logits(self):
|
||||
self.setup()
|
||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, from_logits=True)
|
||||
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||
result = auc_obj(self.y_true, self.y_pred_logits)
|
||||
|
||||
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
|
||||
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
|
||||
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = (0.75 * 1 + 0.25 * 0)
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_manual_thresholds(self):
|
||||
self.setup()
|
||||
# Verify that when specified, thresholds are used instead of num_thresholds.
|
||||
@ -1418,6 +1434,10 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
||||
self.y_pred = constant_op.constant(
|
||||
np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T,
|
||||
dtype=dtypes.float32)
|
||||
|
||||
epsilon = 1e-12
|
||||
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||
|
||||
self.y_true_good = constant_op.constant(
|
||||
np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T)
|
||||
self.y_true_bad = constant_op.constant(
|
||||
@ -1503,6 +1523,21 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
||||
expected_result = (0.875 + 1.0) / 2.0
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_from_logits(self):
|
||||
with self.test_session():
|
||||
self.setup()
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
multi_label=True,
|
||||
from_logits=True)
|
||||
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
expected_result = (0.875 + 1.0) / 2.0
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_sample_weight_flat(self):
|
||||
self.setup()
|
||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False)
|
||||
@ -1572,6 +1607,23 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
||||
expected_result = 1.0 - (3.0 / 32.0)
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_flat_from_logits(self):
|
||||
self.setup()
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=False, from_logits=True)
|
||||
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||
|
||||
# tp = [4, 4, 1, 1, 0]
|
||||
# fp = [4, 1, 0, 0, 0]
|
||||
# fn = [0, 0, 3, 3, 4]
|
||||
# tn = [0, 3, 4, 4, 4]
|
||||
|
||||
# tpr = [1, 1, 0.25, 0.25, 0]
|
||||
# fpr = [1, 0.25, 0, 0, 0]
|
||||
expected_result = 1.0 - (3.0 / 32.0)
|
||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||
|
||||
def test_manual_thresholds(self):
|
||||
with self.test_session():
|
||||
self.setup()
|
||||
|
@ -2266,6 +2266,29 @@ class ResetStatesTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
|
||||
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
|
||||
|
||||
def test_reset_states_auc_from_logits(self):
|
||||
auc_obj = metrics.AUC(num_thresholds=3, from_logits=True)
|
||||
|
||||
model_layers = [layers.Dense(1, kernel_initializer='ones', use_bias=False)]
|
||||
model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,))
|
||||
model.compile(
|
||||
loss='mae',
|
||||
metrics=[auc_obj],
|
||||
optimizer='rmsprop',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
x = np.concatenate((np.ones((25, 4)), -np.ones((25, 4)), -np.ones(
|
||||
(25, 4)), np.ones((25, 4))))
|
||||
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones(
|
||||
(25, 1)), np.zeros((25, 1))))
|
||||
|
||||
for _ in range(2):
|
||||
model.evaluate(x, y)
|
||||
self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
|
||||
self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
|
||||
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
|
||||
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
|
||||
|
||||
def test_reset_states_auc_manual_thresholds(self):
|
||||
auc_obj = metrics.AUC(thresholds=[0.5])
|
||||
model = _get_model([auc_obj])
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
Loading…
Reference in New Issue
Block a user