Add support for AUC and PR-AUC for binary classification models outputting logits.
PiperOrigin-RevId: 345437841 Change-Id: If00e810b0b6037c9c317c5ee143ad44084771a62
This commit is contained in:
parent
dbb41a3f72
commit
4e0f09d355
@ -22,6 +22,7 @@
|
|||||||
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
|
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
|
||||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||||
|
* `tf.keras.metrics.AUC` now supports logit predictions.
|
||||||
|
|
||||||
## Bug Fixes and Other Changes
|
## Bug Fixes and Other Changes
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import math
|
||||||
import types
|
import types
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -37,6 +38,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
|
from tensorflow.python.keras.activations import sigmoid
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
from tensorflow.python.keras.engine import base_layer_utils
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
from tensorflow.python.keras.engine import keras_tensor
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
@ -1844,7 +1846,17 @@ class RecallAtPrecision(SensitivitySpecificityBase):
|
|||||||
|
|
||||||
@keras_export('keras.metrics.AUC')
|
@keras_export('keras.metrics.AUC')
|
||||||
class AUC(Metric):
|
class AUC(Metric):
|
||||||
"""Computes the approximate AUC (Area under the curve) via a Riemann sum.
|
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
|
||||||
|
|
||||||
|
The AUC (Area under the curve) of the ROC (Receiver operating
|
||||||
|
characteristic; default) or PR (Precision Recall) curves are quality measures
|
||||||
|
of binary classifiers. Unlike the accuracy, and like cross-entropy
|
||||||
|
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
|
||||||
|
|
||||||
|
This class approximates AUCs using a Riemann sum: During the metric
|
||||||
|
accumulation phase, predictions are accumulated within predefined buckets
|
||||||
|
by value. The AUC is then computed by interpolating per-bucket averages. These
|
||||||
|
buckets define the evaluated operational points.
|
||||||
|
|
||||||
This metric creates four local variables, `true_positives`, `true_negatives`,
|
This metric creates four local variables, `true_positives`, `true_negatives`,
|
||||||
`false_positives` and `false_negatives` that are used to compute the AUC.
|
`false_positives` and `false_negatives` that are used to compute the AUC.
|
||||||
@ -1862,11 +1874,11 @@ class AUC(Metric):
|
|||||||
dramatically depending on `num_thresholds`. The `thresholds` parameter can be
|
dramatically depending on `num_thresholds`. The `thresholds` parameter can be
|
||||||
used to manually specify thresholds which split the predictions more evenly.
|
used to manually specify thresholds which split the predictions more evenly.
|
||||||
|
|
||||||
For best results, `predictions` should be distributed approximately uniformly
|
For a best approximation of the real AUC, `predictions` should be distributed
|
||||||
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
|
approximately uniformly in the range [0, 1] (if `from_logits=False`). The
|
||||||
approximation may be poor if this is not the case. Setting `summation_method`
|
quality of the AUC approximation may be poor if this is not the case. Setting
|
||||||
to 'minoring' or 'majoring' can help quantify the error in the approximation
|
`summation_method` to 'minoring' or 'majoring' can help quantify the error in
|
||||||
by providing lower or upper bound estimate of the AUC.
|
the approximation by providing a lower or upper bound estimate of the AUC.
|
||||||
|
|
||||||
If `sample_weight` is `None`, weights default to 1.
|
If `sample_weight` is `None`, weights default to 1.
|
||||||
Use `sample_weight` of 0 to mask values.
|
Use `sample_weight` of 0 to mask values.
|
||||||
@ -1912,6 +1924,10 @@ class AUC(Metric):
|
|||||||
label, whereas label_weights depends only on the index of that label
|
label, whereas label_weights depends only on the index of that label
|
||||||
before flattening; therefore `label_weights` should not be used for
|
before flattening; therefore `label_weights` should not be used for
|
||||||
multi-class data.
|
multi-class data.
|
||||||
|
from_logits: boolean indicating whether the predictions (`y_pred` in
|
||||||
|
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
|
||||||
|
when using a keras loss, the `from_logits` constructor argument of the
|
||||||
|
loss should match the AUC `from_logits` constructor argument.
|
||||||
|
|
||||||
Standalone usage:
|
Standalone usage:
|
||||||
|
|
||||||
@ -1933,7 +1949,15 @@ class AUC(Metric):
|
|||||||
Usage with `compile()` API:
|
Usage with `compile()` API:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
|
# Reports the AUC of a model outputting a probability.
|
||||||
|
model.compile(optimizer='sgd',
|
||||||
|
loss=tf.keras.losses.BinaryCrossentropy(),
|
||||||
|
metrics=[tf.keras.metrics.AUC()])
|
||||||
|
|
||||||
|
# Reports the AUC of a model outputting a logit.
|
||||||
|
model.compile(optimizer='sgd',
|
||||||
|
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
|
||||||
|
metrics=[tf.keras.metrics.AUC(from_logits=True)])
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -1946,7 +1970,8 @@ class AUC(Metric):
|
|||||||
thresholds=None,
|
thresholds=None,
|
||||||
multi_label=False,
|
multi_label=False,
|
||||||
num_labels=None,
|
num_labels=None,
|
||||||
label_weights=None):
|
label_weights=None,
|
||||||
|
from_logits=False):
|
||||||
# Validate configurations.
|
# Validate configurations.
|
||||||
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
||||||
metrics_utils.AUCCurve):
|
metrics_utils.AUCCurve):
|
||||||
@ -2006,6 +2031,8 @@ class AUC(Metric):
|
|||||||
else:
|
else:
|
||||||
self.label_weights = None
|
self.label_weights = None
|
||||||
|
|
||||||
|
self._from_logits = from_logits
|
||||||
|
|
||||||
self._built = False
|
self._built = False
|
||||||
if self.multi_label:
|
if self.multi_label:
|
||||||
if num_labels:
|
if num_labels:
|
||||||
@ -2105,6 +2132,10 @@ class AUC(Metric):
|
|||||||
# multi_label is False. Otherwise the averaging of individual label AUCs is
|
# multi_label is False. Otherwise the averaging of individual label AUCs is
|
||||||
# handled in AUC.result
|
# handled in AUC.result
|
||||||
label_weights = None if self.multi_label else self.label_weights
|
label_weights = None if self.multi_label else self.label_weights
|
||||||
|
|
||||||
|
if self._from_logits:
|
||||||
|
y_pred = sigmoid(y_pred)
|
||||||
|
|
||||||
with ops.control_dependencies(deps):
|
with ops.control_dependencies(deps):
|
||||||
return metrics_utils.update_confusion_matrix_variables(
|
return metrics_utils.update_confusion_matrix_variables(
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1142,6 +1142,8 @@ class AUCTest(test.TestCase, parameterized.TestCase):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
self.num_thresholds = 3
|
self.num_thresholds = 3
|
||||||
self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32)
|
self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32)
|
||||||
|
epsilon = 1e-12
|
||||||
|
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||||
self.y_true = constant_op.constant([0, 0, 1, 1])
|
self.y_true = constant_op.constant([0, 0, 1, 1])
|
||||||
self.sample_weight = [1, 2, 3, 4]
|
self.sample_weight = [1, 2, 3, 4]
|
||||||
|
|
||||||
@ -1264,6 +1266,20 @@ class AUCTest(test.TestCase, parameterized.TestCase):
|
|||||||
expected_result = (0.75 * 1 + 0.25 * 0)
|
expected_result = (0.75 * 1 + 0.25 * 0)
|
||||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
|
def test_unweighted_from_logits(self):
|
||||||
|
self.setup()
|
||||||
|
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, from_logits=True)
|
||||||
|
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||||
|
result = auc_obj(self.y_true, self.y_pred_logits)
|
||||||
|
|
||||||
|
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||||
|
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
|
||||||
|
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
|
||||||
|
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
|
||||||
|
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||||
|
expected_result = (0.75 * 1 + 0.25 * 0)
|
||||||
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
def test_manual_thresholds(self):
|
def test_manual_thresholds(self):
|
||||||
self.setup()
|
self.setup()
|
||||||
# Verify that when specified, thresholds are used instead of num_thresholds.
|
# Verify that when specified, thresholds are used instead of num_thresholds.
|
||||||
@ -1418,6 +1434,10 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.y_pred = constant_op.constant(
|
self.y_pred = constant_op.constant(
|
||||||
np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T,
|
np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T,
|
||||||
dtype=dtypes.float32)
|
dtype=dtypes.float32)
|
||||||
|
|
||||||
|
epsilon = 1e-12
|
||||||
|
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||||
|
|
||||||
self.y_true_good = constant_op.constant(
|
self.y_true_good = constant_op.constant(
|
||||||
np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T)
|
np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T)
|
||||||
self.y_true_bad = constant_op.constant(
|
self.y_true_bad = constant_op.constant(
|
||||||
@ -1503,6 +1523,21 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
|||||||
expected_result = (0.875 + 1.0) / 2.0
|
expected_result = (0.875 + 1.0) / 2.0
|
||||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
|
def test_unweighted_from_logits(self):
|
||||||
|
with self.test_session():
|
||||||
|
self.setup()
|
||||||
|
auc_obj = metrics.AUC(
|
||||||
|
num_thresholds=self.num_thresholds,
|
||||||
|
multi_label=True,
|
||||||
|
from_logits=True)
|
||||||
|
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||||
|
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||||
|
|
||||||
|
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||||
|
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||||
|
expected_result = (0.875 + 1.0) / 2.0
|
||||||
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
def test_sample_weight_flat(self):
|
def test_sample_weight_flat(self):
|
||||||
self.setup()
|
self.setup()
|
||||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False)
|
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False)
|
||||||
@ -1572,6 +1607,23 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
|
|||||||
expected_result = 1.0 - (3.0 / 32.0)
|
expected_result = 1.0 - (3.0 / 32.0)
|
||||||
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
|
def test_unweighted_flat_from_logits(self):
|
||||||
|
self.setup()
|
||||||
|
auc_obj = metrics.AUC(
|
||||||
|
num_thresholds=self.num_thresholds, multi_label=False, from_logits=True)
|
||||||
|
self.evaluate(variables.variables_initializer(auc_obj.variables))
|
||||||
|
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||||
|
|
||||||
|
# tp = [4, 4, 1, 1, 0]
|
||||||
|
# fp = [4, 1, 0, 0, 0]
|
||||||
|
# fn = [0, 0, 3, 3, 4]
|
||||||
|
# tn = [0, 3, 4, 4, 4]
|
||||||
|
|
||||||
|
# tpr = [1, 1, 0.25, 0.25, 0]
|
||||||
|
# fpr = [1, 0.25, 0, 0, 0]
|
||||||
|
expected_result = 1.0 - (3.0 / 32.0)
|
||||||
|
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
|
||||||
|
|
||||||
def test_manual_thresholds(self):
|
def test_manual_thresholds(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|||||||
@ -2266,6 +2266,29 @@ class ResetStatesTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
|
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
|
||||||
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
|
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
|
||||||
|
|
||||||
|
def test_reset_states_auc_from_logits(self):
|
||||||
|
auc_obj = metrics.AUC(num_thresholds=3, from_logits=True)
|
||||||
|
|
||||||
|
model_layers = [layers.Dense(1, kernel_initializer='ones', use_bias=False)]
|
||||||
|
model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,))
|
||||||
|
model.compile(
|
||||||
|
loss='mae',
|
||||||
|
metrics=[auc_obj],
|
||||||
|
optimizer='rmsprop',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
|
x = np.concatenate((np.ones((25, 4)), -np.ones((25, 4)), -np.ones(
|
||||||
|
(25, 4)), np.ones((25, 4))))
|
||||||
|
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones(
|
||||||
|
(25, 1)), np.zeros((25, 1))))
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
model.evaluate(x, y)
|
||||||
|
self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
|
||||||
|
self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
|
||||||
|
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
|
||||||
|
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
|
||||||
|
|
||||||
def test_reset_states_auc_manual_thresholds(self):
|
def test_reset_states_auc_manual_thresholds(self):
|
||||||
auc_obj = metrics.AUC(thresholds=[0.5])
|
auc_obj = metrics.AUC(thresholds=[0.5])
|
||||||
model = _get_model([auc_obj])
|
model = _get_model([auc_obj])
|
||||||
|
|||||||
@ -134,7 +134,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
|||||||
@ -134,7 +134,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
|||||||
@ -134,7 +134,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_loss"
|
name: "add_loss"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user