Fix a bug that overestimated AUC_PR: when TP and FP are both 0, the precision should be 0 instead of 1.

PiperOrigin-RevId: 185842713
Authored by A. Unique TensorFlower on 2018-02-15 08:26:14 -08:00; committed by TensorFlower Gardener
parent cf74c749aa
commit c356d28001
10 changed files with 62 additions and 43 deletions
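The root cause sits in the PR branch of the auc metric (first hunk of the last file below): precision was computed as (tp + epsilon) / (tp + fp + epsilon), so at thresholds where nothing is predicted positive (tp == fp == 0) the ratio collapses to epsilon / epsilon == 1, pinning the endpoint of the precision-recall curve at 1 and inflating the interpolated area. A minimal NumPy sketch of the two formulas (the per-threshold tp/fp counts are hypothetical, not taken from any test below):

import numpy as np

epsilon = 1e-7                # plays the same role as epsilon in metrics.auc
tp = np.array([2., 1., 0.])   # hypothetical true positives per threshold
fp = np.array([2., 0., 0.])   # hypothetical false positives per threshold

old_prec = (tp + epsilon) / (tp + fp + epsilon)  # -> [0.5, 1.0, 1.0]
new_prec = tp / (tp + fp + epsilon)              # -> [0.5, ~1.0, 0.0]
print(old_prec, new_prec)

Only the last threshold differs: the old formula reported precision 1 on zero evidence, the fixed one reports 0. Every expected AUC_PR value in the tests below decreases for exactly this reason.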


@@ -446,7 +446,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
-keys.AUC_PR: 0.7639,
+keys.AUC_PR: 0.5972,
}
self._test_eval(
head=head,
@@ -478,7 +478,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
-keys.AUC_PR: 0.7639,
+keys.AUC_PR: 0.5972,
}
self._test_eval(
head=head,
@@ -509,7 +509,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
-keys.AUC_PR: 0.7639,
+keys.AUC_PR: 0.5972,
}
self._test_eval(
head=head,
@@ -543,7 +543,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
-keys.AUC_PR: 0.7639,
+keys.AUC_PR: 0.5972,
}
self._test_eval(
head=head,
@@ -573,7 +573,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
-keys.AUC_PR: 0.7639,
+keys.AUC_PR: 0.5972,
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
@@ -621,7 +621,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.2000,
-keys.AUC_PR: 0.7833,
+keys.AUC_PR: 0.5833,
}
# Assert spec contains expected tensors.
@@ -1095,7 +1095,7 @@ class MultiLabelHead(test.TestCase):
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.4977,
-keys.AUC_PR: 0.6645,
+keys.AUC_PR: 0.4037,
}
self._test_eval(
head=head,


@@ -306,8 +306,8 @@ class MultiHeadTest(test.TestCase):
# this assert tests that the algorithm remains consistent.
keys.AUC + '/head1': 0.1667,
keys.AUC + '/head2': 0.3333,
-keys.AUC_PR + '/head1': 0.6667,
-keys.AUC_PR + '/head2': 0.5000,
+keys.AUC_PR + '/head1': 0.49999964,
+keys.AUC_PR + '/head2': 0.33333313,
}
# Assert spec contains expected tensors.


@@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase):
"auc_precision_recall": 0.166667,
"auc_precision_recall/class0": 0,
"auc_precision_recall/class1": 0.,
"auc_precision_recall/class2": 1.,
"auc_precision_recall/class2": 0.49999,
"labels/actual_label_mean/class0": self._labels[0][0],
"labels/actual_label_mean/class1": self._labels[0][1],
"labels/actual_label_mean/class2": self._labels[0][2],
@@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase):
"accuracy/baseline_label_mean": label_mean,
"accuracy/threshold_0.500000_mean": 1. / 2,
"auc": 1. / 2,
"auc_precision_recall": 0.749999,
"auc_precision_recall": 0.25,
"labels/actual_label_mean": label_mean,
"labels/prediction_mean": .731059, # softmax
"loss": expected_loss,


@@ -1802,9 +1802,9 @@ class StreamingAUCTest(test.TestCase):
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
with self.test_session() as sess:
@@ -1816,9 +1816,9 @@ class StreamingAUCTest(test.TestCase):
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
with self.test_session() as sess:
@@ -1830,9 +1830,9 @@ class StreamingAUCTest(test.TestCase):
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
@@ -1865,9 +1865,9 @@ class StreamingAUCTest(test.TestCase):
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(1, sess.run(update_op), 6)
+self.assertAlmostEqual(0.49999976, sess.run(update_op), 6)
-self.assertAlmostEqual(1, auc.eval(), 6)
+self.assertAlmostEqual(0.49999976, auc.eval(), 6)
def testWithMultipleUpdates(self):
num_samples = 1000
@@ -6689,7 +6689,8 @@ class CohenKappaTest(test.TestCase):
# [[0, 25, 0],
# [0, 0, 25],
# [25, 0, 0]]
-# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
+# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+# labels, predictions)
expect = -0.333333333333
with self.test_session() as sess:
@@ -6748,7 +6749,8 @@ class CohenKappaTest(test.TestCase):
weights_t: weights[batch_start:batch_end]
})
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
-# labels_np, predictions_np, sample_weight=weights_np)
+# labels_np, predictions_np,
+# sample_weight=weights_np)
expect = 0.289965397924
self.assertAlmostEqual(expect, kappa.eval(), 5)


@@ -1075,7 +1075,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
metric_keys.MetricKeys.LABEL_MEAN: 1.,
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
metric_keys.MetricKeys.AUC: 0.,
-metric_keys.MetricKeys.AUC_PR: 1.,
+metric_keys.MetricKeys.AUC_PR: 0.5,
}
else:
# Multi classes: loss = 1 * -log ( softmax(logits)[label] )
@@ -1136,7 +1136,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
metric_keys.MetricKeys.LABEL_MEAN: 0.5,
metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
metric_keys.MetricKeys.AUC: 0.5,
-metric_keys.MetricKeys.AUC_PR: 0.75,
+metric_keys.MetricKeys.AUC_PR: 0.25,
}
else:
# Expand logits since batch_size=2
@@ -1212,7 +1212,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
metric_keys.MetricKeys.ACCURACY_BASELINE: (
max(label_mean, 1-label_mean)),
metric_keys.MetricKeys.AUC: 0.5,
-metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),
+metric_keys.MetricKeys.AUC_PR: 0.16666645,
}
else:
# Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )


@@ -1041,7 +1041,7 @@ class BaseDNNClassifierEvaluateTest(object):
# There is no good way to calculate AUC for only two data points. But
# that is what the algorithm returns.
metric_keys.MetricKeys.AUC: 0.5,
-metric_keys.MetricKeys.AUC_PR: 0.75,
+metric_keys.MetricKeys.AUC_PR: 0.25,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))


@@ -1558,7 +1558,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
-keys.AUC_PR: 1.,
+keys.AUC_PR: 0.74999905,
}
# Assert spec contains expected tensors.
@@ -1636,7 +1636,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
-keys.AUC_PR: 1.,
+keys.AUC_PR: 0.75,
}
# Assert predictions, loss, and metrics.
@@ -1741,7 +1741,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
keys.LABEL_MEAN: 2./2,
keys.ACCURACY_BASELINE: 2./2,
keys.AUC: 0.,
-keys.AUC_PR: 1.,
+keys.AUC_PR: 0.74999905,
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1.,
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1.,
keys.RECALL_AT_THRESHOLD % thresholds[0]: 1.,
@@ -2188,7 +2188,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
keys.LABEL_MEAN: expected_label_mean,
keys.ACCURACY_BASELINE: 1 - expected_label_mean,
keys.AUC: .45454565,
-keys.AUC_PR: .6737757325172424,
+keys.AUC_PR: .21923049,
}
# Assert spec contains expected tensors.
@@ -2487,7 +2487,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
# We cannot reliably calculate AUC with only 4 data points, but the
# values should not change because of backwards-compatibility.
keys.AUC: 0.5222,
-keys.AUC_PR: 0.7341,
+keys.AUC_PR: 0.5119,
}
tol = 1e-2


@@ -1342,7 +1342,7 @@ class BaseLinearClassifierEvaluationTest(object):
metric_keys.MetricKeys.LABEL_MEAN: 1.,
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
metric_keys.MetricKeys.AUC: 0.,
-metric_keys.MetricKeys.AUC_PR: 1.,
+metric_keys.MetricKeys.AUC_PR: 0.5,
}
else:
# Multi classes: loss = 1 * -log ( soft_max(logits)[label] )


@@ -1105,9 +1105,9 @@ class AUCTest(test.TestCase):
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
with self.test_session() as sess:
@@ -1119,9 +1119,9 @@ class AUCTest(test.TestCase):
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
with self.test_session() as sess:
@@ -1133,9 +1133,26 @@ class AUCTest(test.TestCase):
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
+self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
-self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
+self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
+def testFourthAUCPRSpecialCase(self):
+# Create the labels and data.
+labels = np.array([
+0, 0, 0, 0, 0, 0, 0, 1, 0, 1])
+predictions = np.array([
+0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35])
+with self.test_session() as sess:
+auc, _ = metrics.auc(
+labels, predictions, curve='PR', num_thresholds=11)
+sess.run(variables.local_variables_initializer())
+# Since this is only approximate, we can't expect a 6-digit match,
+# although with more samples/thresholds the accuracy should improve.
+self.assertAlmostEqual(0.0, auc.eval(), delta=0.001)
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
@@ -1161,16 +1178,16 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
-def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
+def testRecallOneAndPrecisionOne(self):
with self.test_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.auc(labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
-self.assertAlmostEqual(1, sess.run(update_op), 6)
+self.assertAlmostEqual(0.5, sess.run(update_op), 6)
-self.assertAlmostEqual(1, auc.eval(), 6)
+self.assertAlmostEqual(0.5, auc.eval(), 6)
def np_auc(self, predictions, labels, weights):
"""Computes the AUC explicitly using Numpy.

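The renamed testRecallOneAndPrecisionOne case above is the clearest illustration of the endpoint change. With all four predictions and labels equal to 1, every threshold below 1 gives tp=4, fp=0 (precision and recall 1), while the top threshold gives tp=fp=0, where precision is now 0 instead of 1. A back-of-envelope check using the same trapezoidal summation metrics.auc applies (the three-point curve stands in for the default 200 thresholds):

import numpy as np

rec = np.array([1.0, 1.0, 0.0])       # recall per threshold, low to high
prec_old = np.array([1.0, 1.0, 1.0])  # old endpoint pinned at precision 1
prec_new = np.array([1.0, 1.0, 0.0])  # fixed endpoint drops to precision 0

def trapezoid(x, y):
  # mirrors summation_method='trapezoidal': sum of dx times mean of adjacent y
  return np.sum((x[:-1] - x[1:]) * (y[:-1] + y[1:]) / 2.0)

print(trapezoid(rec, prec_old))  # 1.0 -> the old expected value
print(trapezoid(rec, prec_new))  # 0.5 -> matches the new 0.49999976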

@@ -672,7 +672,7 @@ def auc(labels,
x = fp_rate
y = rec
else: # curve == 'PR'.
-prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
+prec = math_ops.div(tp, tp + fp + epsilon)
x = rec
y = prec
if summation_method == 'trapezoidal':
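Scope note (editorial, not part of the commit): only the curve == 'PR' branch changes here. The epsilon is kept in the denominator so the division stays defined when tp + fp == 0; it is dropped only from the numerator, where it had been manufacturing a precision of 1 out of zero positive predictions.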
@@ -923,8 +923,8 @@ def mean_per_class_accuracy(labels,
weights = array_ops.reshape(weights, [-1])
weights = math_ops.to_float(weights)
-is_correct = is_correct * weights
-ones = ones * weights
+is_correct *= weights
+ones *= weights
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)