From 4e0f09d355ab40f5b9084e5014faa2db1a3e9da3 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 3 Dec 2020 06:28:09 -0800
Subject: [PATCH] Add support for AUC and PR-AUC for binary classification
 model outputting logits.

PiperOrigin-RevId: 345437841
Change-Id: If00e810b0b6037c9c317c5ee143ad44084771a62
---
 RELEASE.md                                    |  1 +
 tensorflow/python/keras/metrics.py            | 47 ++++++++++++++---
 .../keras/metrics_confusion_matrix_test.py    | 52 +++++++++++++++++++
 tensorflow/python/keras/metrics_test.py       | 23 ++++++++
 .../v1/tensorflow.keras.metrics.-a-u-c.pbtxt  |  2 +-
 .../v2/tensorflow.keras.metrics.-a-u-c.pbtxt  |  2 +-
 .../golden/v2/tensorflow.metrics.-a-u-c.pbtxt |  2 +-
 7 files changed, 118 insertions(+), 11 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index 22cd2c6f0bc..6657265ab64 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -22,6 +22,7 @@
   * Added `profile_data_directory` to `EmbeddingConfigSpec` in
     `_tpu_estimator_embedding.py`. This allows embedding lookup statistics
     gathered at runtime to be used in embedding layer partitioning decisions.
+* `tf.keras.metrics.AUC` now support logit predictions.
 
 ## Bug Fixes and Other Changes
 
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index b8df287cf5a..f05fb910a72 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -21,6 +21,7 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
+import math
 import types
 
 import numpy as np
@@ -37,6 +38,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.activations import sigmoid
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import keras_tensor
@@ -1844,7 +1846,17 @@ class RecallAtPrecision(SensitivitySpecificityBase):
 
 @keras_export('keras.metrics.AUC')
 class AUC(Metric):
-  """Computes the approximate AUC (Area under the curve) via a Riemann sum.
+  """Approximates the AUC (Area under the curve) of the ROC or PR curves.
+
+  The AUC (Area under the curve) of the ROC (Receiver operating
+  characteristic; default) or PR (Precision Recall) curves are quality measures
+  of binary classifiers. Unlike the accuracy, and like cross-entropy
+  losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
+
+  This classes approximates AUCs using a Riemann sum: During the metric
+  accumulation phrase, predictions are accumulated within predefined buckets
+  by value. The AUC is then computed by interpolating per-bucket averages. These
+  buckets define the evaluated operational points.
 
   This metric creates four local variables, `true_positives`, `true_negatives`,
   `false_positives` and `false_negatives` that are used to compute the AUC.
@@ -1862,11 +1874,11 @@ class AUC(Metric):
   dramatically depending on `num_thresholds`. The `thresholds` parameter can be
   used to manually specify thresholds which split the predictions more evenly.
 
-  For best results, `predictions` should be distributed approximately uniformly
-  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
-  approximation may be poor if this is not the case. Setting `summation_method`
-  to 'minoring' or 'majoring' can help quantify the error in the approximation
-  by providing lower or upper bound estimate of the AUC.
+  For a best approximation of the real AUC, `predictions` should be distributed
+  approximately uniformly in the range [0, 1] (if `from_logits=False`). The
+  quality of the AUC approximation may be poor if this is not the case. Setting
+  `summation_method` to 'minoring' or 'majoring' can help quantify the error in
+  the approximation by providing lower or upper bound estimate of the AUC.
 
   If `sample_weight` is `None`, weights default to 1.
   Use `sample_weight` of 0 to mask values.
@@ -1912,6 +1924,10 @@ class AUC(Metric):
       label, whereas label_weights depends only on the index of that label
       before flattening; therefore `label_weights` should not be used for
       multi-class data.
+    from_logits: boolean indicating whether the predictions (`y_pred` in
+      `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
+      when using a keras loss, the `from_logits` constructor argument of the
+      loss should match the AUC `from_logits` constructor argument.
 
   Standalone usage:
 
@@ -1933,7 +1949,15 @@ class AUC(Metric):
   Usage with `compile()` API:
 
   ```python
-  model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
+  # Reports the AUC of a model outputing a probability.
+  model.compile(optimizer='sgd',
+                loss=tf.keras.losses.BinaryCrossentropy(),
+                metrics=[tf.keras.metrics.AUC()])
+
+  # Reports the AUC of a model outputing a logit.
+  model.compile(optimizer='sgd',
+                loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
+                metrics=[tf.keras.metrics.AUC(from_logits=True)])
   ```
   """
 
@@ -1946,7 +1970,8 @@ class AUC(Metric):
                thresholds=None,
                multi_label=False,
                num_labels=None,
-               label_weights=None):
+               label_weights=None,
+               from_logits=False):
     # Validate configurations.
     if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
         metrics_utils.AUCCurve):
@@ -2006,6 +2031,8 @@ class AUC(Metric):
     else:
       self.label_weights = None
 
+    self._from_logits = from_logits
+
     self._built = False
     if self.multi_label:
       if num_labels:
@@ -2105,6 +2132,10 @@ class AUC(Metric):
     # multi_label is False. Otherwise the averaging of individual label AUCs is
     # handled in AUC.result
     label_weights = None if self.multi_label else self.label_weights
+
+    if self._from_logits:
+      y_pred = sigmoid(y_pred)
+
     with ops.control_dependencies(deps):
       return metrics_utils.update_confusion_matrix_variables(
           {
diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py
index 58c84557ec9..a3103daaa9e 100644
--- a/tensorflow/python/keras/metrics_confusion_matrix_test.py
+++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py
@@ -1142,6 +1142,8 @@ class AUCTest(test.TestCase, parameterized.TestCase):
   def setup(self):
     self.num_thresholds = 3
     self.y_pred = constant_op.constant([0, 0.5, 0.3, 0.9], dtype=dtypes.float32)
+    epsilon = 1e-12
+    self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
     self.y_true = constant_op.constant([0, 0, 1, 1])
     self.sample_weight = [1, 2, 3, 4]
 
@@ -1264,6 +1266,20 @@ class AUCTest(test.TestCase, parameterized.TestCase):
     expected_result = (0.75 * 1 + 0.25 * 0)
     self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
 
+  def test_unweighted_from_logits(self):
+    self.setup()
+    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, from_logits=True)
+    self.evaluate(variables.variables_initializer(auc_obj.variables))
+    result = auc_obj(self.y_true, self.y_pred_logits)
+
+    # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+    # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+    # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+    # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+    # widths = [(1 - 0), (0 - 0)] = [1, 0]
+    expected_result = (0.75 * 1 + 0.25 * 0)
+    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_manual_thresholds(self):
     self.setup()
     # Verify that when specified, thresholds are used instead of num_thresholds.
@@ -1418,6 +1434,10 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
     self.y_pred = constant_op.constant(
         np.array([[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]]).T,
         dtype=dtypes.float32)
+
+    epsilon = 1e-12
+    self.y_pred_logits = -math_ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
+
     self.y_true_good = constant_op.constant(
         np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T)
     self.y_true_bad = constant_op.constant(
@@ -1503,6 +1523,21 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
       expected_result = (0.875 + 1.0) / 2.0
       self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
 
+  def test_unweighted_from_logits(self):
+    with self.test_session():
+      self.setup()
+      auc_obj = metrics.AUC(
+          num_thresholds=self.num_thresholds,
+          multi_label=True,
+          from_logits=True)
+      self.evaluate(variables.variables_initializer(auc_obj.variables))
+      result = auc_obj(self.y_true_good, self.y_pred_logits)
+
+      # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
+      # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
+      expected_result = (0.875 + 1.0) / 2.0
+      self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_sample_weight_flat(self):
     self.setup()
     auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, multi_label=False)
@@ -1572,6 +1607,23 @@ class MultiAUCTest(test.TestCase, parameterized.TestCase):
     expected_result = 1.0 - (3.0 / 32.0)
     self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
 
+  def test_unweighted_flat_from_logits(self):
+    self.setup()
+    auc_obj = metrics.AUC(
+        num_thresholds=self.num_thresholds, multi_label=False, from_logits=True)
+    self.evaluate(variables.variables_initializer(auc_obj.variables))
+    result = auc_obj(self.y_true_good, self.y_pred_logits)
+
+    # tp = [4, 4, 1, 1, 0]
+    # fp = [4, 1, 0, 0, 0]
+    # fn = [0, 0, 3, 3, 4]
+    # tn = [0, 3, 4, 4, 4]
+
+    # tpr = [1, 1, 0.25, 0.25, 0]
+    # fpr = [1, 0.25, 0, 0, 0]
+    expected_result = 1.0 - (3.0 / 32.0)
+    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_manual_thresholds(self):
     with self.test_session():
       self.setup()
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index af3b9e2140b..afdf6658a8d 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -2266,6 +2266,29 @@ class ResetStatesTest(keras_parameterized.TestCase):
       self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
       self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
 
+  def test_reset_states_auc_from_logits(self):
+    auc_obj = metrics.AUC(num_thresholds=3, from_logits=True)
+
+    model_layers = [layers.Dense(1, kernel_initializer='ones', use_bias=False)]
+    model = testing_utils.get_model_from_layers(model_layers, input_shape=(4,))
+    model.compile(
+        loss='mae',
+        metrics=[auc_obj],
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly())
+
+    x = np.concatenate((np.ones((25, 4)), -np.ones((25, 4)), -np.ones(
+        (25, 4)), np.ones((25, 4))))
+    y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones(
+        (25, 1)), np.zeros((25, 1))))
+
+    for _ in range(2):
+      model.evaluate(x, y)
+      self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
+
   def test_reset_states_auc_manual_thresholds(self):
     auc_obj = metrics.AUC(thresholds=[0.5])
     model = _get_model([auc_obj])
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
index 294eaaab14d..fe32dbb739b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
@@ -134,7 +134,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
index 294eaaab14d..fe32dbb739b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
@@ -134,7 +134,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
index d7bd6b20982..d9fbb5dc97e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
@@ -134,7 +134,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"