Add the ability to specify custom thresholds for AUC computation. As noted in the original documentation, the metric uses thresholds distributed uniformly over [0, 1] by default, which yields a poor approximation when the predictions themselves are not distributed roughly uniformly over that range.

Currently, a client who discovers that the lower and upper bound estimates of AUC (obtained via the 'minoring' and 'majoring' `summation_method`s) disagree has only one recourse: increase `num_thresholds`. If predictions are peaked in a very narrow range, `num_thresholds` may need to be set very high to get any resolution there, substantially increasing the size of the required local variables. The new optional `thresholds` parameter instead allows clients to specify the thresholds manually, which gives higher metric resolution at equal variable cost.
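For illustration, a minimal sketch of the intended usage (the data and threshold values here are made up, and a TF 2.x-style eager environment is assumed): a few hand-placed thresholds resolve predictions peaked in a narrow band that the default uniform grid cannot split.

    import numpy as np
    import tensorflow as tf

    # Hypothetical data: predictions peaked in [0.300, 0.301]. A uniform grid
    # would need num_thresholds in the thousands to separate them.
    y_true = np.array([0., 1., 0., 1.])
    y_pred = np.array([0.3001, 0.3004, 0.3006, 0.3009])

    # Hand-placed thresholds inside the peaked region; only 3 + 2 endpoint
    # buckets of local state are needed.
    auc = tf.keras.metrics.AUC(thresholds=[0.3002, 0.3005, 0.3008])
    auc.update_state(y_true, y_pred)
    print(float(auc.result()))  # 0.75, the exact pairwise AUC for this data.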

PiperOrigin-RevId: 240459959
A. Unique TensorFlower 2019-03-26 17:09:05 -07:00 committed by TensorFlower Gardener
parent 919b38007e
commit d7ded17058
9 changed files with 157 additions and 29 deletions

View File

@@ -1570,7 +1570,8 @@ class AUC(Metric):
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
   closely approximating the true AUC. The quality of the approximation may vary
-  dramatically depending on `num_thresholds`.
+  dramatically depending on `num_thresholds`. The `thresholds` parameter can be
+  used to manually specify thresholds which split the predictions more evenly.

   For best results, `predictions` should be distributed approximately uniformly
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
@@ -1608,7 +1609,8 @@ class AUC(Metric):
                curve='ROC',
                summation_method='interpolation',
                name=None,
-               dtype=None):
+               dtype=None,
+               thresholds=None):
     """Creates an `AUC` instance.

     Args:
@@ -1625,10 +1627,14 @@ class AUC(Metric):
         'majoring' that does the opposite.
       name: (Optional) string name of the metric instance.
       dtype: (Optional) data type of the metric result.
+      thresholds: (Optional) A list of floating point values to use as the
+        thresholds for discretizing the curve. If set, the `num_thresholds`
+        parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+        equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
+        be automatically included with these to correctly handle predictions
+        equal to exactly 0 or 1.
     """
     # Validate configurations.
-    if num_thresholds <= 1:
-      raise ValueError('`num_thresholds` must be > 1.')
     if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
         metrics_utils.AUCCurve):
       raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
@@ -1642,7 +1648,24 @@ class AUC(Metric):
           summation_method, list(metrics_utils.AUCSummationMethod)))

     # Update properties.
-    self.num_thresholds = num_thresholds
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      self.num_thresholds = len(thresholds) + 2
+      thresholds = sorted(thresholds)
+    else:
+      if num_thresholds <= 1:
+        raise ValueError('`num_thresholds` must be > 1.')
+
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      self.num_thresholds = num_thresholds
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either
+    # threshold method to account for floating point imprecisions.
+    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
+
     if isinstance(curve, metrics_utils.AUCCurve):
       self.curve = curve
     else:
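The branch above is self-contained enough to check in isolation. A standalone sketch of the same selection logic (plain Python; `epsilon` stands in for `K.epsilon()`, and `build_thresholds` is a hypothetical helper, not part of the commit):

    epsilon = 1e-7  # Stand-in for K.epsilon().

    def build_thresholds(num_thresholds=200, thresholds=None):
      # Mirrors the constructor logic: manual thresholds win over
      # num_thresholds.
      if thresholds is not None:
        num_thresholds = len(thresholds) + 2
        thresholds = sorted(thresholds)
      else:
        if num_thresholds <= 1:
          raise ValueError('`num_thresholds` must be > 1.')
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
      # Endpoints sit just outside [0, 1] so predictions of exactly 0 or 1
      # fall strictly on one side of a threshold.
      return num_thresholds, [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # num_thresholds=4, thresholds ~= [-1e-7, 0.3, 0.5, 1 + 1e-7]
    print(build_thresholds(thresholds=[0.5, 0.3]))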
@@ -1657,28 +1680,21 @@ class AUC(Metric):
     # Create metric variables
     self.true_positives = self.add_weight(
         'true_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.true_negatives = self.add_weight(
         'true_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_positives = self.add_weight(
         'false_positives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)
     self.false_negatives = self.add_weight(
         'false_negatives',
-        shape=(num_thresholds,),
+        shape=(self.num_thresholds,),
         initializer=init_ops.zeros_initializer)

-    # Compute `num_thresholds` thresholds in [0, 1]
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
-    self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
-    # epsilon - to account for floating point imprecisions.
-
   def update_state(self, y_true, y_pred, sample_weight=None):
     """Accumulates confusion matrix statistics.
@@ -1804,15 +1820,18 @@ class AUC(Metric):
         name=self.name)

   def reset_states(self):
-    num_thresholds = len(self.thresholds)
     K.batch_set_value(
-        [(v, np.zeros((num_thresholds,))) for v in self.variables])
+        [(v, np.zeros((self.num_thresholds,))) for v in self.variables])

   def get_config(self):
     config = {
         'num_thresholds': self.num_thresholds,
         'curve': self.curve.value,
         'summation_method': self.summation_method.value,
+        # We remove the endpoint thresholds as an inverse of how the thresholds
+        # were initialized. This ensures that a metric initialized from this
+        # config has the same thresholds.
+        'thresholds': self.thresholds[1:-1],
     }
     base_config = super(AUC, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
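The `[1:-1]` slice is what makes serialization round-trip: `get_config` stores only the user-supplied thresholds, and the constructor re-attaches the epsilon endpoints on reconstruction. A sketch of the invariant (assuming the behavior introduced in this commit, via the public `tf.keras.metrics.AUC`):

    import tensorflow as tf

    auc_obj = tf.keras.metrics.AUC(thresholds=[0.3, 0.5])
    config = auc_obj.get_config()
    # Endpoint thresholds are stripped, so the config matches the user input.
    assert config['thresholds'] == [0.3, 0.5]

    # from_config re-adds the {-epsilon, 1 + epsilon} endpoints.
    auc_obj2 = tf.keras.metrics.AUC.from_config(config)
    assert auc_obj2.num_thresholds == auc_obj.num_thresholds == 4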

View File

@@ -963,7 +963,7 @@ class AUCTest(test.TestCase):
     old_config = auc_obj.get_config()
     self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))

-    # Check save and restore config
+    # Check save and restore config.
     auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
     self.assertEqual(auc_obj2.name, 'auc_1')
     self.assertEqual(len(auc_obj2.variables), 4)
@@ -973,6 +973,36 @@ class AUCTest(test.TestCase):
                      metrics_utils.AUCSummationMethod.MAJORING)
     new_config = auc_obj2.get_config()
     self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
+
+  def test_config_manual_thresholds(self):
+    auc_obj = metrics.AUC(
+        num_thresholds=None,
+        curve='PR',
+        summation_method='majoring',
+        name='auc_1',
+        thresholds=[0.3, 0.5])
+    self.assertEqual(auc_obj.name, 'auc_1')
+    self.assertEqual(len(auc_obj.variables), 4)
+    self.assertEqual(auc_obj.num_thresholds, 4)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
+    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+
+    old_config = auc_obj.get_config()
+    self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
+
+    # Check save and restore config.
+    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
+    self.assertEqual(auc_obj2.name, 'auc_1')
+    self.assertEqual(len(auc_obj2.variables), 4)
+    self.assertEqual(auc_obj2.num_thresholds, 4)
+    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
+    self.assertEqual(auc_obj2.summation_method,
+                     metrics_utils.AUCSummationMethod.MAJORING)
+    new_config = auc_obj2.get_config()
+    self.assertDictEqual(old_config, new_config)
+    self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)

   def test_value_is_idempotent(self):
     self.setup()
@@ -1010,6 +1040,23 @@ class AUCTest(test.TestCase):
     expected_result = (0.75 * 1 + 0.25 * 0)
     self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

+  def test_manual_thresholds(self):
+    self.setup()
+    # Verify that when specified, thresholds are used instead of
+    # num_thresholds.
+    auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
+    self.assertEqual(auc_obj.num_thresholds, 3)
+    self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
+    self.evaluate(variables.variables_initializer(auc_obj.variables))
+    result = auc_obj(self.y_true, self.y_pred)
+
+    # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+    # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+    # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+    # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+    # widths = [(1 - 0), (0 - 0)] = [1, 0]
+    expected_result = (0.75 * 1 + 0.25 * 0)
+    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
+
   def test_weighted_roc_interpolation(self):
     self.setup()
     auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)

View File

@@ -1945,6 +1945,21 @@ class ResetStatesTest(keras_parameterized.TestCase):
       self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
       self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)

+  def test_reset_states_auc_manual_thresholds(self):
+    auc_obj = metrics.AUC(thresholds=[0.5])
+    model = _get_model([auc_obj])
+    x = np.concatenate((np.ones((25, 4)), np.zeros((25, 4)), np.zeros((25, 4)),
+                        np.ones((25, 4))))
+    y = np.concatenate((np.ones((25, 1)), np.zeros((25, 1)), np.ones((25, 1)),
+                        np.zeros((25, 1))))
+
+    for _ in range(2):
+      model.evaluate(x, y)
+      self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
+      self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
+
   def test_reset_states_mean_iou(self):
     m_obj = metrics.MeanIoU(num_classes=2)
     model = _get_model([m_obj])

View File

@@ -1193,6 +1193,34 @@ class AUCTest(test.TestCase):
       self.assertAlmostEqual(0.7, auc.eval(), 5)

+  @test_util.run_deprecated_v1
+  def testManualThresholds(self):
+    with self.cached_session():
+      # Verifies that thresholds passed in to the `thresholds` parameter are
+      # used correctly.
+      # The default thresholds do not split the second and third predictions.
+      # Thus, when we provide manual thresholds which correctly split it, we
+      # get an accurate AUC value.
+      predictions = constant_op.constant(
+          [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
+      weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4))
+      thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0]
+
+      default_auc, default_update_op = metrics.auc(labels,
+                                                   predictions,
+                                                   weights=weights)
+      manual_auc, manual_update_op = metrics.auc(labels,
+                                                 predictions,
+                                                 weights=weights,
+                                                 thresholds=thresholds)
+
+      self.evaluate(variables.local_variables_initializer())
+      self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
+      self.assertAlmostEqual(0.875, default_auc.eval(), 3)
+
+      self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
+      self.assertAlmostEqual(0.75, manual_auc.eval(), 3)
+
   # Regarding the AUC-PR tests: note that the preferred method when
   # calculating AUC-PR is summation_method='careful_interpolation'.
   @test_util.run_deprecated_v1
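As a quick sanity check of the two expected values, the exact AUC is just the fraction of correctly ordered positive/negative pairs, and only the 0.3001/0.3003 pair is misordered (plain Python, data taken from the test above):

    # AUC = P(score(positive) > score(negative)) over all pos/neg pairs.
    pos = [0.3001, 0.72]   # predictions with label 1
    neg = [0.12, 0.3003]   # predictions with label 0
    pairs = [(p, n) for p in pos for n in neg]
    print(sum(p > n for p, n in pairs) / len(pairs))
    # 0.75: manual thresholds recover this; the default grid lumps 0.3001 and
    # 0.3003 into one bucket and reports 0.875 instead.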

View File

@@ -635,7 +635,8 @@ def auc(labels,
         updates_collections=None,
         curve='ROC',
         name=None,
-        summation_method='trapezoidal'):
+        summation_method='trapezoidal',
+        thresholds=None):
   """Computes the approximate AUC via a Riemann sum.

   The `auc` function creates four local variables, `true_positives`,
@@ -657,7 +658,9 @@ def auc(labels,
   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
   approximation may be poor if this is not the case. Setting `summation_method`
   to 'minoring' or 'majoring' can help quantify the error in the approximation
-  by providing lower or upper bound estimate of the AUC.
+  by providing lower or upper bound estimate of the AUC. The `thresholds`
+  parameter can be used to manually specify thresholds which split the
+  predictions more evenly.

   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
@@ -691,6 +694,12 @@ def auc(labels,
       Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
       (to be deprecated soon) as it applies the same method for ROC, and a
       better one (see Davis & Goadrich 2006 for details) for the PR curve.
+    thresholds: An optional list of floating point values to use as the
+      thresholds for discretizing the curve. If set, the `num_thresholds`
+      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
+      automatically included with these to correctly handle predictions equal
+      to exactly 0 or 1.

   Returns:
     auc: A scalar `Tensor` representing the current area-under-curve.
@@ -713,10 +722,20 @@ def auc(labels,
                                      (labels, predictions, weights)):
     if curve != 'ROC' and curve != 'PR':
       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
-    kepsilon = 1e-7  # to account for floating point imprecisions
-    thresholds = [
-        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
-    ]
+
+    kepsilon = 1e-7  # To account for floating point imprecisions.
+    if thresholds is not None:
+      # If specified, use the supplied thresholds.
+      thresholds = sorted(thresholds)
+      num_thresholds = len(thresholds) + 2
+    else:
+      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+      # (0, 1).
+      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                    for i in range(num_thresholds - 2)]
+
+    # Add an endpoint "threshold" below zero and above one for either threshold
+    # method.
     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

     values, update_ops = _confusion_matrix_at_thresholds(
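The same knob is usable from the v1 endpoint. A minimal sketch mirroring `testManualThresholds` above (assuming a graph-mode session reached via `tensorflow.compat.v1`):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    labels = tf.constant([0, 1, 0, 1], shape=(1, 4))
    predictions = tf.constant([0.12, 0.3001, 0.3003, 0.72], shape=(1, 4))

    # Hand-placed thresholds split the two nearly tied predictions at 0.3002.
    auc, update_op = tf.metrics.auc(labels, predictions,
                                    thresholds=[0.0, 0.2, 0.3002, 0.6, 1.0])

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(auc))  # ~0.75, vs ~0.875 with the default uniform grid.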

View File

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
View File

@@ -6,7 +6,7 @@ tf_module {
   }
   member_method {
     name: "auc"
-    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\'], "
+    argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\', \'None\'], "
   }
   member_method {
     name: "average_precision_at_k"
View File

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
View File

@@ -91,7 +91,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"