Add the ability to specify custom thresholds for AUC computation.

As noted in the original documentation, the metric uses thresholds distributed uniformly over [0, 1] by default, which is inaccurate when the predictions themselves are not distributed roughly uniformly over that range.

Currently, if a client compares the lower and upper bound estimates of the AUC produced by the `summation_method` options 'minoring' and 'majoring' and discovers that the approximation is loose, their only recourse is to increase `num_thresholds`. If predictions are peaked in a very narrow range, `num_thresholds` may need to be set very high to get any resolution there, substantially increasing the size of the required local variables. The new optional `thresholds` parameter instead allows clients to specify the thresholds to use manually, which allows for higher metric resolution at equal variable cost.
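
For illustration, a hedged usage sketch of the new parameter through the Keras metric (eager execution and illustrative data assumed; the threshold values are hypothetical, not prescriptive):

import tensorflow as tf

# Hypothetical data: predictions peaked in a narrow band around 0.3.
y_true = [0, 0, 1, 1]
y_pred = [0.2999, 0.3001, 0.3003, 0.3005]

# Manual thresholds placed inside the band. The default uniform grid of 200
# thresholds would lump all four predictions into a single bucket and report
# roughly 0.5 instead of the exact 1.0.
auc = tf.keras.metrics.AUC(thresholds=[0.2998, 0.3000, 0.3002, 0.3004])
auc.update_state(y_true, y_pred)
print(auc.result())  # ~1.0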

PiperOrigin-RevId: 240459959
A. Unique TensorFlower 2019-03-26 17:09:05 -07:00 committed by TensorFlower Gardener
parent 919b38007e
commit d7ded17058
9 changed files with 157 additions and 29 deletions

@@ -1570,7 +1570,8 @@ class AUC(Metric):
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
   closely approximating the true AUC. The quality of the approximation may vary
-  dramatically depending on `num_thresholds`.
+  dramatically depending on `num_thresholds`. The `thresholds` parameter can be
+  used to manually specify thresholds which split the predictions more evenly.
 
   For best results, `predictions` should be distributed approximately uniformly
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
@@ -1608,7 +1609,8 @@ class AUC(Metric):
                curve='ROC',
                summation_method='interpolation',
                name=None,
-               dtype=None):
+               dtype=None,
+               thresholds=None):
     """Creates an `AUC` instance.
 
     Args:
@@ -1625,10 +1627,14 @@ class AUC(Metric):
         'majoring' that does the opposite.
       name: (Optional) string name of the metric instance.
       dtype: (Optional) data type of the metric result.
+      thresholds: (Optional) A list of floating point values to use as the
+        thresholds for discretizing the curve. If set, the `num_thresholds`
+        parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+        equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
+        be automatically included with these to correctly handle predictions
+        equal to exactly 0 or 1.
     """
     # Validate configurations.
-    if num_thresholds <= 1:
-      raise ValueError('`num_thresholds` must be > 1.')
     if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
         metrics_utils.AUCCurve):
       raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
@@ -1642,7 +1648,24 @@ class AUC(Metric):
           summation_method, list(metrics_utils.AUCSummationMethod)))
 
     # Update properties.
-    self.num_thresholds = num_thresholds
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      self.num_thresholds = len(thresholds) + 2
+      thresholds = sorted(thresholds)
+    else:
+      if num_thresholds <= 1:
+        raise ValueError('`num_thresholds` must be > 1.')
+
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      self.num_thresholds = num_thresholds
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either
+    # threshold method to account for floating point imprecisions.
+    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
+
     if isinstance(curve, metrics_utils.AUCCurve):
       self.curve = curve
     else:
@@ -1657,28 +1680,21 @@ class AUC(Metric):
     # Create metric variables
     self.true_positives = self.add_weight(
         'true_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.true_negatives = self.add_weight(
         'true_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_positives = self.add_weight(
         'false_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_negatives = self.add_weight(
         'false_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
 
-    # Compute `num_thresholds` thresholds in [0, 1]
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
-    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
-    # epsilon - to account for floating point imprecisions.
-
   def update_state(self, y_true, y_pred, sample_weight=None):
     """Accumulates confusion matrix statistics.
@@ -1804,15 +1820,18 @@ class AUC(Metric):
         name=self.name)
 
   def reset_states(self):
-    num_thresholds = len(self.thresholds)
     K.batch_set_value(
-        [(v, np.zeros((num_thresholds,))) for v in self.variables])
+        [(v, np.zeros((self.num_thresholds,))) for v in self.variables])
 
   def get_config(self):
     config = {
         'num_thresholds': self.num_thresholds,
         'curve': self.curve.value,
         'summation_method': self.summation_method.value,
+        # We remove the endpoint thresholds as an inverse of how the thresholds
+        # were initialized. This ensures that a metric initialized from this
+        # config has the same thresholds.
+        'thresholds': self.thresholds[1:-1],
     }
     base_config = super(AUC, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
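
To make the construction above concrete, here is a standalone sketch of the threshold-list logic from `__init__`. It is illustrative only: `build_thresholds` is a hypothetical helper, and a fixed constant stands in for `K.epsilon()`.

# Sketch of the threshold construction performed in AUC.__init__.
EPSILON = 1e-7

def build_thresholds(num_thresholds=200, thresholds=None):
  if thresholds is not None:
    # Manual thresholds; the two endpoints appended below are why the metric
    # sets self.num_thresholds = len(thresholds) + 2.
    inner = sorted(thresholds)
  else:
    # Default: (num_thresholds - 2) evenly spaced values in (0, 1).
    inner = [(i + 1) * 1.0 / (num_thresholds - 1)
             for i in range(num_thresholds - 2)]
  return [0.0 - EPSILON] + inner + [1.0 + EPSILON]

print(build_thresholds(thresholds=[0.3, 0.5]))  # [-1e-07, 0.3, 0.5, 1.0000001]
print(build_thresholds(num_thresholds=5))       # [-1e-07, 0.25, 0.5, 0.75, 1.0000001]

Note that `get_config` serializes `self.thresholds[1:-1]`, stripping the two endpoints, so a metric rebuilt from its config re-adds them and ends up with identical thresholds.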

@@ -963,7 +963,7 @@ class AUCTest(test.TestCase):
     old_config = auc_obj.get_config()
     self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
 
-    # Check save and restore config
+    # Check save and restore config.
     auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
     self.assertEqual(auc_obj2.name, 'auc_1')
     self.assertEqual(len(auc_obj2.variables), 4)
@@ -973,6 +973,36 @@ class AUCTest(test.TestCase):
                      metrics_utils.AUCSummationMethod.MAJORING)
     new_config = auc_obj2.get_config()
     self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
+
+  def test_config_manual_thresholds(self):
+    auc_obj = metrics.AUC(
+        num_thresholds=None,
+        curve='PR',
+        summation_method='majoring',
+        name='auc_1',
+        thresholds=[0.3, 0.5])
+    self.assertEqual(auc_obj.name, 'auc_1')
+    self.assertEqual(len(auc_obj.variables), 4)
+    self.assertEqual(auc_obj.num_thresholds, 4)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
+    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+    old_config = auc_obj.get_config()
+    self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
+
+    # Check save and restore config.
+    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
+    self.assertEqual(auc_obj2.name, 'auc_1')
+    self.assertEqual(len(auc_obj2.variables), 4)
+    self.assertEqual(auc_obj2.num_thresholds, 4)
+    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj2.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+    new_config = auc_obj2.get_config()
+    self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
 
   def test_value_is_idempotent(self):
     self.setup()
@@ -1010,6 +1040,23 @@ class AUCTest(test.TestCase):
     expected_result = (0.75 * 1 + 0.25 * 0)
     self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
 
+  def test_manual_thresholds(self):
+    self.setup()
+    # Verify that when specified, thresholds are used instead of num_thresholds.
+    auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
+    self.assertEqual(auc_obj.num_thresholds, 3)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
+    self.evaluate(variables.variables_initializer(auc_obj.variables))
+    result = auc_obj(self.y_true, self.y_pred)
+
+    # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+    # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+    # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+    # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+    # widths = [(1 - 0), (0 - 0)] = [1, 0]
+    expected_result = (0.75 * 1 + 0.25 * 0)
+    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_weighted_roc_interpolation(self):
     self.setup()
     auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
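
The expected value in the comments of test_manual_thresholds can be verified independently of TensorFlow. This sketch (NumPy only, with the confusion-matrix counts copied from the test comments) reproduces the interpolated Riemann sum:

import numpy as np

# Counts at thresholds [0 - eps, 0.5, 1 + eps], from the comments above.
tp = np.array([2., 1., 0.])
fp = np.array([2., 0., 0.])
fn = np.array([0., 1., 2.])
tn = np.array([0., 2., 2.])

recall = tp / (tp + fn)   # [1.0, 0.5, 0.0]
fp_rate = fp / (fp + tn)  # [1.0, 0.0, 0.0]

# Trapezoidal Riemann sum under the ROC curve.
heights = (recall[:-1] + recall[1:]) / 2  # [0.75, 0.25]
widths = fp_rate[:-1] - fp_rate[1:]       # [1.0, 0.0]
print(np.sum(heights * widths))           # 0.75, matching expected_result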

@@ -1945,6 +1945,21 @@ class ResetStatesTest(keras_parameterized.TestCase):
       self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
       self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
 
+  def test_reset_states_auc_manual_thresholds(self):
+    auc_obj = metrics.AUC(thresholds=[0.5])
+    model = _get_model([auc_obj])
+    x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
+                        np.ones((25, 4))))
+    y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
+                        np.zeros((25, 1))))
+
+    for _ in range(2):
+      model.evaluate(x, y)
+      self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
+
   def test_reset_states_mean_iou(self):
     m_obj = metrics.MeanIoU(num_classes=2)
     model = _get_model([m_obj])

@@ -1193,6 +1193,34 @@ class AUCTest(test.TestCase):
       self.assertAlmostEqual(0.7, auc.eval(), 5)
 
+  @test_util.run_deprecated_v1
+  def testManualThresholds(self):
+    with self.cached_session():
+      # Verifies that thresholds passed in to the `thresholds` parameter are
+      # used correctly.
+      # The default thresholds do not split the second and third predictions.
+      # Thus, when we provide manual thresholds which correctly split it, we get
+      # an accurate AUC value.
+      predictions = constant_op.constant(
+          [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
+      weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4))
+      thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0]
+
+      default_auc, default_update_op = metrics.auc(labels,
+                                                   predictions,
+                                                   weights=weights)
+      manual_auc, manual_update_op = metrics.auc(labels,
+                                                 predictions,
+                                                 weights=weights,
+                                                 thresholds=thresholds)
+
+      self.evaluate(variables.local_variables_initializer())
+      self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
+      self.assertAlmostEqual(0.875, default_auc.eval(), 3)
+
+      self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
+      self.assertAlmostEqual(0.75, manual_auc.eval(), 3)
+
   # Regarding the AUC-PR tests: note that the preferred method when
   # calculating AUC-PR is summation_method='careful_interpolation'.
   @test_util.run_deprecated_v1
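
A quick arithmetic check of why the default grid fails in this test: with the default num_thresholds=200, the interior thresholds are spaced 1/199 apart, far coarser than the 0.0002 gap between the second and third predictions.

num_thresholds = 200
spacing = 1.0 / (num_thresholds - 1)
print(spacing)          # ~0.005025, gap between adjacent default thresholds
print(0.3003 - 0.3001)  # ~0.0002: both predictions land in the same bucket
# Splitting them with a uniform grid would need num_thresholds on the order of
# 1 / 0.0002 = 5000; the manual five-element list above does it directly.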

@@ -635,7 +635,8 @@ def auc(labels,
         updates_collections=None,
         curve='ROC',
         name=None,
-        summation_method='trapezoidal'):
+        summation_method='trapezoidal',
+        thresholds=None):
   """Computes the approximate AUC via a Riemann sum.
 
   The `auc` function creates four local variables, `true_positives`,
@@ -657,7 +658,9 @@ def auc(labels,
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
   approximation may be poor if this is not the case. Setting `summation_method`
   to 'minoring' or 'majoring' can help quantify the error in the approximation
-  by providing lower or upper bound estimate of the AUC.
+  by providing lower or upper bound estimate of the AUC. The `thresholds`
+  parameter can be used to manually specify thresholds which split the
+  predictions more evenly.
 
   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
@@ -691,6 +694,12 @@ def auc(labels,
       Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
       (to be deprecated soon) as it applies the same method for ROC, and a
       better one (see Davis & Goadrich 2006 for details) for the PR curve.
+    thresholds: An optional list of floating point values to use as the
+      thresholds for discretizing the curve. If set, the `num_thresholds`
+      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
+      automatically included with these to correctly handle predictions equal to
+      exactly 0 or 1.
 
   Returns:
     auc: A scalar `Tensor` representing the current area-under-curve.
@@ -713,10 +722,20 @@ def auc(labels,
                                      (labels, predictions, weights)):
     if curve != 'ROC' and curve != 'PR':
       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
-    kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
+
+    kepsilon = 1e-7  # To account for floating point imprecisions.
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      thresholds = sorted(thresholds)
+      num_thresholds = len(thresholds) + 2
+    else:
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either threshold
+    # method.
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
     values, update_ops = _confusion_matrix_at_thresholds(
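
For completeness, a sketch of calling the v1 endpoint with manual thresholds (assumes a TF 1.x graph-mode runtime; the data mirrors testManualThresholds above):

import tensorflow as tf

labels = tf.constant([0, 1, 0, 1], shape=(1, 4))
predictions = tf.constant([0.12, 0.3001, 0.3003, 0.72], shape=(1, 4))
auc, update_op = tf.metrics.auc(
    labels, predictions, thresholds=[0.0, 0.2, 0.3002, 0.6, 1.0])

with tf.Session() as sess:
  # The metric state lives in local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(auc))  # ~0.75, per the test above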

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

@@ -6,7 +6,7 @@ tf_module {
   }
   member_method {
     name: "auc"
-    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\'], "
+    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\', \'None\'], "
   }
   member_method {
     name: "average_precision_at_k"

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"