diff --git a/RELEASE.md b/RELEASE.md index 22cd2c6f0bc..6657265ab64 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -22,6 +22,7 @@ * Added `profile_data_directory` to `EmbeddingConfigSpec` in `_tpu_estimator_embedding.py`. This allows embedding lookup statistics gathered at runtime to be used in embedding layer partitioning decisions. +* `tf.keras.metrics.AUC` now supports logit predictions. ## Bug Fixes and Other Changes diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index b8df287cf5a..f05fb910a72 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function import abc +import math import types import numpy as np @@ -37,6 +38,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K +from tensorflow.python.keras.activations import sigmoid from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import keras_tensor @@ -1844,7 +1846,17 @@ class RecallAtPrecision(SensitivitySpecificityBase): @keras_export('keras.metrics.AUC') class AUC(Metric): - """Computes the approximate AUC (Area under the curve) via a Riemann sum. + """Approximates the AUC (Area under the curve) of the ROC or PR curves. + + The AUC (Area under the curve) of the ROC (Receiver operating + characteristic; default) or PR (Precision Recall) curves are quality measures + of binary classifiers. Unlike the accuracy, and like cross-entropy + losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. + + This class approximates AUCs using a Riemann sum: During the metric + accumulation phase, predictions are accumulated within predefined buckets + by value. 
The AUC is then computed by interpolating per-bucket averages. These + buckets define the evaluated operational points. This metric creates four local variables, `true_positives`, `true_negatives`, `false_positives` and `false_negatives` that are used to compute the AUC. @@ -1862,11 +1874,11 @@ class AUC(Metric): dramatically depending on `num_thresholds`. The `thresholds` parameter can be used to manually specify thresholds which split the predictions more evenly. - For best results, `predictions` should be distributed approximately uniformly - in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC - approximation may be poor if this is not the case. Setting `summation_method` - to 'minoring' or 'majoring' can help quantify the error in the approximation - by providing lower or upper bound estimate of the AUC. + For a best approximation of the real AUC, `predictions` should be distributed + approximately uniformly in the range [0, 1] (if `from_logits=False`). The + quality of the AUC approximation may be poor if this is not the case. Setting + `summation_method` to 'minoring' or 'majoring' can help quantify the error in + the approximation by providing lower or upper bound estimate of the AUC. If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values. @@ -1912,6 +1924,10 @@ class AUC(Metric): label, whereas label_weights depends only on the index of that label before flattening; therefore `label_weights` should not be used for multi-class data. + from_logits: boolean indicating whether the predictions (`y_pred` in + `update_state`) are probabilities or sigmoid logits. As a rule of thumb, + when using a keras loss, the `from_logits` constructor argument of the + loss should match the AUC `from_logits` constructor argument. 
Standalone usage: @@ -1933,7 +1949,15 @@ class AUC(Metric): Usage with `compile()` API: ```python - model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()]) + # Reports the AUC of a model outputting a probability. + model.compile(optimizer='sgd', + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=[tf.keras.metrics.AUC()]) + + # Reports the AUC of a model outputting a logit. + model.compile(optimizer='sgd', + loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.AUC(from_logits=True)]) ``` """ @@ -1946,7 +1970,8 @@ class AUC(Metric): thresholds=None, multi_label=False, num_labels=None, - label_weights=None): + label_weights=None, + from_logits=False): # Validate configurations. if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( metrics_utils.AUCCurve): @@ -2006,6 +2031,8 @@ class AUC(Metric): else: self.label_weights = None + self._from_logits = from_logits + self._built = False if self.multi_label: if num_labels: @@ -2105,6 +2132,10 @@ class AUC(Metric): # multi_label is False. 
Otherwise the averaging of individual label AUCs is # handled in AUC.result label_weights = None if self.multi_label else self.label_weights + + if self._from_logits: + y_pred = sigmoid(y_pred) + with ops.control_dependencies(deps): return metrics_utils.update_confusion_matrix_variables( { diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py index 58c84557ec9..a3103daaa9e 100644 --- a/tensorflow/python/keras/metrics_confusion_matrix_test.py +++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py @@ -1142,6 +1142,8 @@ class AUCTest(test.TestCase, parameterized.TestCase): def setup(self): self.num_thresholds = 3 self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32) + epsilon = 1e-12 + self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0) self.y_true = constant_op.constant([0, 0, 1, 1]) self.sample_weight = [1, 2, 3, 4] @@ -1264,6 +1266,20 @@ class AUCTest(test.TestCase, parameterized.TestCase): expected_result = (0.75 * 1 + 0.25 * 0) self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def test_unweighted_from_logits(self): + self.setup() + auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, from_logits=True) + self.evaluate(variables.variables_initializer(auc_obj.variables)) + result = auc_obj(self.y_true, self.y_pred_logits) + + # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0] + # fp_rate = [2/2, 0, 0] = [1, 0, 0] + # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = (0.75 * 1 + 0.25 * 0) + self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def test_manual_thresholds(self): self.setup() # Verify that when specified, thresholds are used instead of num_thresholds. 
@@ -1418,6 +1434,10 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase): self.y_pred = constant_op.constant( np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T, dtype=dtypes.float32) + + epsilon = 1e-12 + self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0) + self.y_true_good = constant_op.constant( np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T) self.y_true_bad = constant_op.constant( @@ -1503,6 +1523,21 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase): expected_result = (0.875 + 1.0) / 2.0 self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def test_unweighted_from_logits(self): + with self.test_session(): + self.setup() + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + multi_label=True, + from_logits=True) + self.evaluate(variables.variables_initializer(auc_obj.variables)) + result = auc_obj(self.y_true_good, self.y_pred_logits) + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] + expected_result = (0.875 + 1.0) / 2.0 + self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def test_sample_weight_flat(self): self.setup() auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False) @@ -1572,6 +1607,23 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase): expected_result = 1.0 - (3.0 / 32.0) self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def test_unweighted_flat_from_logits(self): + self.setup() + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=False, from_logits=True) + self.evaluate(variables.variables_initializer(auc_obj.variables)) + result = auc_obj(self.y_true_good, self.y_pred_logits) + + # tp = [4, 4, 1, 1, 0] + # fp = [4, 1, 0, 0, 0] + # fn = [0, 0, 3, 3, 4] + # tn = [0, 3, 4, 4, 4] + + # tpr = [1, 1, 0.25, 0.25, 0] + # fpr = [1, 0.25, 0, 0, 0] + expected_result = 1.0 - (3.0 / 32.0) + self.assertAllClose(self.evaluate(result), expected_result, 1e-3) + def 
test_manual_thresholds(self): with self.test_session(): self.setup() diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index af3b9e2140b..afdf6658a8d 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -2266,6 +2266,29 @@ class ResetStatesTest(keras_parameterized.TestCase): self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.) self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.) + def test_reset_states_auc_from_logits(self): + auc_obj = metrics.AUC(num_thresholds=3, from_logits=True) + + model_layers = [layers.Dense(1, kernel_initializer='ones', use_bias=False)] + model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,)) + model.compile( + loss='mae', + metrics=[auc_obj], + optimizer='rmsprop', + run_eagerly=testing_utils.should_run_eagerly()) + + x = np.concatenate((np.ones((25, 4)), -np.ones((25, 4)), -np.ones( + (25, 4)), np.ones((25, 4)))) + y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones( + (25, 1)), np.zeros((25, 1)))) + + for _ in range(2): + model.evaluate(x, y) + self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.) + self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.) + self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.) + self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.) 
+ def test_reset_states_auc_manual_thresholds(self): auc_obj = metrics.AUC(thresholds=[0.5]) model = _get_model([auc_obj]) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt index 294eaaab14d..fe32dbb739b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt index 294eaaab14d..fe32dbb739b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', 
\'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt index d7bd6b20982..d9fbb5dc97e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], " } member_method { name: "add_loss"