diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 02429706505..15230c4fd4d 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -1570,7 +1570,8 @@ class AUC(Metric):
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
   closely approximating the true AUC. The quality of the approximation may vary
-  dramatically depending on `num_thresholds`.
+  dramatically depending on `num_thresholds`. The `thresholds` parameter can be
+  used to manually specify thresholds which split the predictions more evenly.
 
   For best results, `predictions` should be distributed approximately uniformly
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
@@ -1608,7 +1609,8 @@ class AUC(Metric):
                curve='ROC',
                summation_method='interpolation',
                name=None,
-               dtype=None):
+               dtype=None,
+               thresholds=None):
     """Creates an `AUC` instance.
 
     Args:
@@ -1625,10 +1627,14 @@ class AUC(Metric):
         'majoring' that does the opposite.
       name: (Optional) string name of the metric instance.
       dtype: (Optional) data type of the metric result.
+      thresholds: (Optional) A list of floating point values to use as the
+        thresholds for discretizing the curve. If set, the `num_thresholds`
+        parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+        equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
+        be automatically included with these to correctly handle predictions
+        equal to exactly 0 or 1.
     """
     # Validate configurations.
-    if num_thresholds <= 1:
-      raise ValueError('`num_thresholds` must be > 1.')
     if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
         metrics_utils.AUCCurve):
       raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
@@ -1642,7 +1648,24 @@ class AUC(Metric):
               summation_method, list(metrics_utils.AUCSummationMethod)))
 
     # Update properties.
-    self.num_thresholds = num_thresholds
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      self.num_thresholds = len(thresholds) + 2
+      thresholds = sorted(thresholds)
+    else:
+      if num_thresholds <= 1:
+        raise ValueError('`num_thresholds` must be > 1.')
+
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      self.num_thresholds = num_thresholds
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either
+    # threshold method to account for floating point imprecisions.
+    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
+
     if isinstance(curve, metrics_utils.AUCCurve):
       self.curve = curve
     else:
@@ -1657,28 +1680,21 @@ class AUC(Metric):
     # Create metric variables
     self.true_positives = self.add_weight(
         'true_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.true_negatives = self.add_weight(
         'true_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_positives = self.add_weight(
         'false_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_negatives = self.add_weight(
         'false_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
 
-    # Compute `num_thresholds` thresholds in [0, 1]
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
-    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
-    # epsilon - to account for floating point imprecisions.
-
   def update_state(self, y_true, y_pred, sample_weight=None):
     """Accumulates confusion matrix statistics.
 
@@ -1804,15 +1820,18 @@ class AUC(Metric):
           name=self.name)
 
   def reset_states(self):
-    num_thresholds = len(self.thresholds)
     K.batch_set_value(
-        [(v, np.zeros((num_thresholds,))) for v in self.variables])
+        [(v, np.zeros((self.num_thresholds,))) for v in self.variables])
 
   def get_config(self):
     config = {
         'num_thresholds': self.num_thresholds,
         'curve': self.curve.value,
         'summation_method': self.summation_method.value,
+        # We remove the endpoint thresholds as an inverse of how the thresholds
+        # were initialized. This ensures that a metric initialized from this
+        # config has the same thresholds.
+        'thresholds': self.thresholds[1:-1],
     }
     base_config = super(AUC, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
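Example (not part of the patch): a minimal sketch of the new `thresholds` keyword
on the Keras metric, assuming the change above is applied and eager execution is
available; the input values are illustrative.

    import tensorflow as tf

    # Two interior cut points; endpoints just below 0 and just above 1 are
    # added automatically, so num_thresholds becomes len(thresholds) + 2 == 4
    # and the num_thresholds argument is ignored.
    auc = tf.keras.metrics.AUC(thresholds=[0.3, 0.5])
    auc.update_state([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.9])
    print(auc.num_thresholds)   # 4
    print(float(auc.result()))  # interpolated ROC AUC over those buckets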
diff --git a/tensorflow/python/keras/metrics_confusion_matrix_test.py b/tensorflow/python/keras/metrics_confusion_matrix_test.py
index 972f7b6de7b..92c904c6aa3 100644
--- a/tensorflow/python/keras/metrics_confusion_matrix_test.py
+++ b/tensorflow/python/keras/metrics_confusion_matrix_test.py
@@ -963,7 +963,7 @@ class AUCTest(test.TestCase):
     old_config = auc_obj.get_config()
     self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
 
-    # Check save and restore config
+    # Check save and restore config.
     auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
     self.assertEqual(auc_obj2.name, 'auc_1')
     self.assertEqual(len(auc_obj2.variables), 4)
@@ -973,6 +973,36 @@ class AUCTest(test.TestCase):
                      metrics_utils.AUCSummationMethod.MAJORING)
     new_config = auc_obj2.get_config()
     self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
+
+  def test_config_manual_thresholds(self):
+    auc_obj = metrics.AUC(
+        num_thresholds=None,
+        curve='PR',
+        summation_method='majoring',
+        name='auc_1',
+        thresholds=[0.3, 0.5])
+    self.assertEqual(auc_obj.name, 'auc_1')
+    self.assertEqual(len(auc_obj.variables), 4)
+    self.assertEqual(auc_obj.num_thresholds, 4)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
+    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+    old_config = auc_obj.get_config()
+    self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
+
+    # Check save and restore config.
+    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
+    self.assertEqual(auc_obj2.name, 'auc_1')
+    self.assertEqual(len(auc_obj2.variables), 4)
+    self.assertEqual(auc_obj2.num_thresholds, 4)
+    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj2.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+    new_config = auc_obj2.get_config()
+    self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
 
   def test_value_is_idempotent(self):
     self.setup()
@@ -1010,6 +1040,23 @@ class AUCTest(test.TestCase):
     expected_result = (0.75 * 1 + 0.25 * 0)
     self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
 
+  def test_manual_thresholds(self):
+    self.setup()
+    # Verify that when specified, thresholds are used instead of num_thresholds.
+    auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
+    self.assertEqual(auc_obj.num_thresholds, 3)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
+    self.evaluate(variables.variables_initializer(auc_obj.variables))
+    result = auc_obj(self.y_true, self.y_pred)
+
+    # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+    # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+    # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+    # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+    # widths = [(1 - 0), (0 - 0)] = [1, 0]
+    expected_result = (0.75 * 1 + 0.25 * 0)
+    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_weighted_roc_interpolation(self):
     self.setup()
     auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index d210f0ebeea..89485b28ee8 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -1945,6 +1945,21 @@ class ResetStatesTest(keras_parameterized.TestCase):
       self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
       self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
 
+  def test_reset_states_auc_manual_thresholds(self):
+    auc_obj = metrics.AUC(thresholds=[0.5])
+    model = _get_model([auc_obj])
+    x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
+                        np.ones((25, 4))))
+    y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
+                        np.zeros((25, 1))))
+
+    for _ in range(2):
+      model.evaluate(x, y)
+      self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
+
   def test_reset_states_mean_iou(self):
     m_obj = metrics.MeanIoU(num_classes=2)
     model = _get_model([m_obj])
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 64dd5914552..718c2fe7b1d 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -1193,6 +1193,34 @@ class AUCTest(test.TestCase):
 
       self.assertAlmostEqual(0.7, auc.eval(), 5)
 
+  @test_util.run_deprecated_v1
+  def testManualThresholds(self):
+    with self.cached_session():
+      # Verifies that thresholds passed in to the `thresholds` parameter are
+      # used correctly.
+      # The default thresholds do not split the second and third predictions.
+      # Thus, when we provide manual thresholds which correctly split them, we
+      # get an accurate AUC value.
+      predictions = constant_op.constant(
+          [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
+      weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4))
+      thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0]
+      default_auc, default_update_op = metrics.auc(labels,
+                                                   predictions,
+                                                   weights=weights)
+      manual_auc, manual_update_op = metrics.auc(labels,
+                                                 predictions,
+                                                 weights=weights,
+                                                 thresholds=thresholds)
+
+      self.evaluate(variables.local_variables_initializer())
+      self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
+      self.assertAlmostEqual(0.875, default_auc.eval(), 3)
+
+      self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
+      self.assertAlmostEqual(0.75, manual_auc.eval(), 3)
+
   # Regarding the AUC-PR tests: note that the preferred method when
   # calculating AUC-PR is summation_method='careful_interpolation'.
   @test_util.run_deprecated_v1
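Aside (not part of the patch): the expected values in testManualThresholds can be
cross-checked against the rank-statistic definition of ROC AUC, i.e. the
probability that a random positive outscores a random negative, with ties
counting one half.

    import itertools

    predictions = [0.12, 0.3001, 0.3003, 0.72]
    labels = [0, 1, 0, 1]
    pos = [p for p, l in zip(predictions, labels) if l == 1]
    neg = [p for p, l in zip(predictions, labels) if l == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p, n in itertools.product(pos, neg))
    print(wins / (len(pos) * len(neg)))  # 0.75, the "manual thresholds" result.
    # The 200 default thresholds cannot separate 0.3001 from 0.3003, so that
    # positive/negative pair is effectively tied: (3 + 0.5) / 4 = 0.875.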
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9ac0cc43b83..e3292e081fe 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -635,7 +635,8 @@ def auc(labels,
         updates_collections=None,
         curve='ROC',
         name=None,
-        summation_method='trapezoidal'):
+        summation_method='trapezoidal',
+        thresholds=None):
   """Computes the approximate AUC via a Riemann sum.
 
   The `auc` function creates four local variables, `true_positives`,
@@ -657,7 +658,9 @@ def auc(labels,
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
   approximation may be poor if this is not the case. Setting `summation_method`
   to 'minoring' or 'majoring' can help quantify the error in the approximation
-  by providing lower or upper bound estimate of the AUC.
+  by providing lower or upper bound estimates of the AUC. The `thresholds`
+  parameter can be used to manually specify thresholds which split the
+  predictions more evenly.
 
   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
@@ -691,6 +694,12 @@ def auc(labels,
       Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
       (to be deprecated soon) as it applies the same method for ROC, and a
       better one (see Davis & Goadrich 2006 for details) for the PR curve.
+    thresholds: An optional list of floating point values to use as the
+      thresholds for discretizing the curve. If set, the `num_thresholds`
+      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
+      automatically included with these to correctly handle predictions equal
+      to exactly 0 or 1.
 
   Returns:
     auc: A scalar `Tensor` representing the current area-under-curve.
@@ -713,10 +722,20 @@ def auc(labels,
                                      (labels, predictions, weights)):
     if curve != 'ROC' and curve != 'PR':
       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
-    kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
+
+    kepsilon = 1e-7  # To account for floating point imprecisions.
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      thresholds = sorted(thresholds)
+      num_thresholds = len(thresholds) + 2
+    else:
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either threshold
+    # method.
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
     values, update_ops = _confusion_matrix_at_thresholds(
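Example (not part of the patch): a graph-mode sketch of calling the patched v1
endpoint, assuming a TF 1.x runtime where `tf.metrics.auc` is the public name
for `metrics.auc`; the data and threshold list mirror testManualThresholds
above.

    import tensorflow as tf  # assumes a 1.x runtime

    labels = tf.constant([0, 1, 0, 1], shape=(1, 4))
    predictions = tf.constant([0.12, 0.3001, 0.3003, 0.72], shape=(1, 4))
    # A manual grid that separates the near-tied predictions at 0.3002.
    auc, update_op = tf.metrics.auc(
        labels, predictions, thresholds=[0.0, 0.2, 0.3002, 0.6, 1.0])
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(auc))  # ~0.75, vs. ~0.875 with the default 200 thresholds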
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
index 539521d8c40..fb80bed0a37 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt
@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
index e9b996c9f53..719ff56452d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
@@ -6,7 +6,7 @@ tf_module {
   }
   member_method {
     name: "auc"
-    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\'], "
+    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\', \'None\'], "
   }
   member_method {
     name: "average_precision_at_k"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
index 539521d8c40..fb80bed0a37 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt
@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
index e82ba1084ab..f79a2343baa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt
@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"