Add support for AUC and PR-AUC for binary classification models outputting logits.

PiperOrigin-RevId: 345437841
Change-Id: If00e810b0b6037c9c317c5ee143ad44084771a62
A. Unique TensorFlower 2020-12-03 06:28:09 -08:00 committed by TensorFlower Gardener
parent dbb41a3f72
commit 4e0f09d355
7 changed files with 118 additions and 11 deletions


@ -22,6 +22,7 @@
* Added `profile_data_directory` to `EmbeddingConfigSpec` in
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
gathered at runtime to be used in embedding layer partitioning decisions.
* `tf.keras.metrics.AUC` now supports logit predictions.
## Bug Fixes and Other Changes


@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import abc
import math
import types
import numpy as np
@ -37,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.activations import sigmoid
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
@ -1844,7 +1846,17 @@ class RecallAtPrecision(SensitivitySpecificityBase):
@keras_export('keras.metrics.AUC')
class AUC(Metric):
"""Computes the approximate AUC (Area under the curve) via a Riemann sum.
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
The AUC (Area under the curve) of the ROC (Receiver operating
characteristic; default) or PR (Precision Recall) curves is a quality
measure of binary classifiers. Unlike accuracy, and like cross-entropy
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum: during the metric
accumulation phase, predictions are accumulated within predefined buckets
by value. The AUC is then computed by interpolating per-bucket averages.
These buckets define the evaluated operational points.
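A minimal standalone sketch of this bucketed approximation (illustrative only; the metric's actual implementation accumulates confusion-matrix variables instead, and the helper name `approx_auc` is hypothetical):
```python
import numpy as np

def approx_auc(y_true, y_pred, num_thresholds=200):
  """Bucketed Riemann-sum ROC-AUC; a sketch, not the real implementation."""
  thresholds = np.linspace(0.0, 1.0, num_thresholds)
  pos, neg = y_true == 1, y_true == 0
  tpr, fpr = [], []
  for t in thresholds:  # each threshold is one operational point
    pred_pos = y_pred >= t
    tpr.append(np.sum(pred_pos & pos) / max(np.sum(pos), 1))
    fpr.append(np.sum(pred_pos & neg) / max(np.sum(neg), 1))
  # Interpolate between consecutive operational points (trapezoidal rule).
  return np.trapz(tpr[::-1], fpr[::-1])

print(approx_auc(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8])))  # ~0.75
```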
This metric creates four local variables, `true_positives`, `true_negatives`,
`false_positives` and `false_negatives` that are used to compute the AUC.
@ -1862,11 +1874,11 @@ class AUC(Metric):
dramatically depending on `num_thresholds`. The `thresholds` parameter can be
used to manually specify thresholds which split the predictions more evenly.
For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
approximation may be poor if this is not the case. Setting `summation_method`
to 'minoring' or 'majoring' can help quantify the error in the approximation
by providing lower or upper bound estimate of the AUC.
For the best approximation of the real AUC, `predictions` should be
distributed approximately uniformly in the range [0, 1] (if
`from_logits=False`). The quality of the AUC approximation may be poor if
this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
can help quantify the error in the approximation by providing a lower or
upper bound estimate of the AUC.
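As a rough illustration of that bounding behavior (a standalone sketch with toy values):
```python
import tensorflow as tf

y_true = [0, 0, 1, 1]
y_pred = [0.1, 0.4, 0.35, 0.8]

# 'minoring' under-estimates and 'majoring' over-estimates the AUC,
# bracketing the default 'interpolation' estimate.
for method in ('minoring', 'interpolation', 'majoring'):
  auc = tf.keras.metrics.AUC(num_thresholds=5, summation_method=method)
  auc.update_state(y_true, y_pred)
  print(method, float(auc.result()))
```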
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
@ -1912,6 +1924,10 @@ class AUC(Metric):
label, whereas label_weights depends only on the index of that label
before flattening; therefore `label_weights` should not be used for
multi-class data.
from_logits: boolean indicating whether the predictions (`y_pred` in
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
when using a Keras loss, the `from_logits` constructor argument of the
loss should match the AUC `from_logits` constructor argument.
Standalone usage:
@ -1933,7 +1949,15 @@ class AUC(Metric):
Usage with `compile()` API:
```python
model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
# Reports the AUC of a model outputting a probability.
model.compile(optimizer='sgd',
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.AUC()])
# Reports the AUC of a model outputting a logit.
model.compile(optimizer='sgd',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.AUC(from_logits=True)])
```
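A quick standalone check of the equivalence (toy values; assumes a TF build that includes this change): raw logits with `from_logits=True` give the same result as `sigmoid(logits)` with the default `from_logits=False`.
```python
import numpy as np
import tensorflow as tf

logits = tf.constant([-2.0, -0.5, 0.5, 2.0])
labels = tf.constant([0, 0, 1, 1])

auc_probs = tf.keras.metrics.AUC()
auc_probs.update_state(labels, tf.sigmoid(logits))

auc_logits = tf.keras.metrics.AUC(from_logits=True)
auc_logits.update_state(labels, logits)

# Both paths threshold the same probabilities, so the results agree.
np.testing.assert_allclose(
    auc_probs.result().numpy(), auc_logits.result().numpy(), atol=1e-6)
```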
"""
@ -1946,7 +1970,8 @@ class AUC(Metric):
thresholds=None,
multi_label=False,
num_labels=None,
label_weights=None):
label_weights=None,
from_logits=False):
# Validate configurations.
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
metrics_utils.AUCCurve):
@ -2006,6 +2031,8 @@ class AUC(Metric):
else:
self.label_weights = None
self._from_logits = from_logits
self._built = False
if self.multi_label:
if num_labels:
@ -2105,6 +2132,10 @@ class AUC(Metric):
# multi_label is False. Otherwise the averaging of individual label AUCs is
# handled in AUC.result
label_weights = None if self.multi_label else self.label_weights
if self._from_logits:
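# Convert sigmoid logits to probabilities so that the confusion-matrix
# thresholds, which live in [0, 1], apply to `y_pred`.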
y_pred = sigmoid(y_pred)
with ops.control_dependencies(deps):
return metrics_utils.update_confusion_matrix_variables(
{


@ -1142,6 +1142,8 @@ class AUCTest(test.TestCase, parameterized.TestCase):
def setup(self):
self.num_thresholds = 3
self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32)
epsilon = 1e-12
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
self.y_true = constant_op.constant([0, 0, 1, 1])
self.sample_weight = [1, 2, 3, 4]
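The `y_pred_logits` above invert the sigmoid; a standalone sanity check of that transform (illustrative):
```python
import numpy as np

y_pred = np.array([0.0, 0.5, 0.3, 0.9])
epsilon = 1e-12  # guards the log against division by zero at p = 0
y_pred_logits = -np.log(1.0 / (y_pred + epsilon) - 1.0)
# Applying the sigmoid recovers the original probabilities.
assert np.allclose(1.0 / (1.0 + np.exp(-y_pred_logits)), y_pred, atol=1e-6)
```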
@ -1264,6 +1266,20 @@ class AUCTest(test.TestCase, parameterized.TestCase):
expected_result = (0.75 * 1 + 0.25 * 0)
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_unweighted_from_logits(self):
self.setup()
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, from_logits=True)
self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true, self.y_pred_logits)
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = (0.75 * 1 + 0.25 * 0)
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
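The expected value in the comments above follows from the trapezoidal rule; a standalone check of the arithmetic:
```python
import numpy as np

recall = np.array([1.0, 0.5, 0.0])   # tp / (tp + fn) per threshold
fp_rate = np.array([1.0, 0.0, 0.0])  # fp / (fp + tn) per threshold
heights = (recall[:-1] + recall[1:]) / 2  # [0.75, 0.25]
widths = fp_rate[:-1] - fp_rate[1:]       # [1.0, 0.0]
assert np.isclose(np.sum(heights * widths), 0.75)
```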
def test_manual_thresholds(self):
self.setup()
# Verify that when specified, thresholds are used instead of num_thresholds.
@ -1418,6 +1434,10 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
self.y_pred = constant_op.constant(
np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T,
dtype=dtypes.float32)
epsilon = 1e-12
self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
self.y_true_good = constant_op.constant(
np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T)
self.y_true_bad = constant_op.constant(
@ -1503,6 +1523,21 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
expected_result = (0.875 + 1.0) / 2.0
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_unweighted_from_logits(self):
with self.test_session():
self.setup()
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
multi_label=True,
from_logits=True)
self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true_good, self.y_pred_logits)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 + 1.0) / 2.0
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_sample_weight_flat(self):
self.setup()
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False)
@ -1572,6 +1607,23 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
expected_result = 1.0 - (3.0 / 32.0)
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_unweighted_flat_from_logits(self):
self.setup()
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=False, from_logits=True)
self.evaluate(variables.variables_initializer(auc_obj.variables))
result = auc_obj(self.y_true_good, self.y_pred_logits)
# tp = [4, 4, 1, 1, 0]
# fp = [4, 1, 0, 0, 0]
# fn = [0, 0, 3, 3, 4]
# tn = [0, 3, 4, 4, 4]
# tpr = [1, 1, 0.25, 0.25, 0]
# fpr = [1, 0.25, 0, 0, 0]
expected_result = 1.0 - (3.0 / 32.0)
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
def test_manual_thresholds(self):
with self.test_session():
self.setup()


@ -2266,6 +2266,29 @@ class ResetStatesTest(keras_parameterized.TestCase):
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
def test_reset_states_auc_from_logits(self):
auc_obj = metrics.AUC(num_thresholds=3, from_logits=True)
model_layers = [layers.Dense(1, kernel_initializer='ones', use_bias=False)]
model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,))
model.compile(
loss='mae',
metrics=[auc_obj],
optimizer='rmsprop',
run_eagerly=testing_utils.should_run_eagerly())
x = np.concatenate((np.ones((25, 4)), -np.ones((25, 4)), -np.ones(
(25, 4)), np.ones((25, 4))))
y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones(
(25, 1)), np.zeros((25, 1))))
for _ in range(2):
model.evaluate(x, y)
self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
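Why each confusion-matrix bucket settles at 25 (a sketch of the arithmetic under this test's setup):
```python
import numpy as np

# With kernel_initializer='ones' and no bias, each 4-feature row of +/-1s
# maps to a logit of +4 or -4; sigmoid(4) ~ 0.982 and sigmoid(-4) ~ 0.018.
logit = np.ones(4) @ np.ones(4)       # 4.0
prob = 1.0 / (1.0 + np.exp(-logit))   # ~0.982, above the middle threshold
# The four 25-row blocks pair (+4, y=1), (-4, y=0), (-4, y=1), (+4, y=0),
# yielding 25 true positives, true negatives, false negatives, and false
# positives respectively at the middle of the three thresholds.
```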
def test_reset_states_auc_manual_thresholds(self):
auc_obj = metrics.AUC(thresholds=[0.5])
model = _get_model([auc_obj])


@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
}
member_method {
name: "add_loss"


@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
}
member_method {
name: "add_loss"


@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
}
member_method {
name: "add_loss"