Fix a bug of overestimating AUC_PR. When TP and FP are both 0s, the precision should be 0 instead of 1.
PiperOrigin-RevId: 185842713
This commit is contained in:
parent
cf74c749aa
commit
c356d28001
@ -446,7 +446,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
keys.AUC_PR: 0.5972,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
@ -478,7 +478,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
keys.AUC_PR: 0.5972,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
@ -509,7 +509,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
keys.AUC_PR: 0.5972,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
@ -543,7 +543,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
keys.AUC_PR: 0.5972,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
@ -573,7 +573,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
keys.AUC_PR: 0.5972,
|
||||
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
|
||||
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
|
||||
keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
|
||||
@ -621,7 +621,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.2000,
|
||||
keys.AUC_PR: 0.7833,
|
||||
keys.AUC_PR: 0.5833,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
@ -1095,7 +1095,7 @@ class MultiLabelHead(test.TestCase):
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.4977,
|
||||
keys.AUC_PR: 0.6645,
|
||||
keys.AUC_PR: 0.4037,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
|
@ -306,8 +306,8 @@ class MultiHeadTest(test.TestCase):
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC + '/head1': 0.1667,
|
||||
keys.AUC + '/head2': 0.3333,
|
||||
keys.AUC_PR + '/head1': 0.6667,
|
||||
keys.AUC_PR + '/head2': 0.5000,
|
||||
keys.AUC_PR + '/head1': 0.49999964,
|
||||
keys.AUC_PR + '/head2': 0.33333313,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
|
@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
"auc_precision_recall": 0.166667,
|
||||
"auc_precision_recall/class0": 0,
|
||||
"auc_precision_recall/class1": 0.,
|
||||
"auc_precision_recall/class2": 1.,
|
||||
"auc_precision_recall/class2": 0.49999,
|
||||
"labels/actual_label_mean/class0": self._labels[0][0],
|
||||
"labels/actual_label_mean/class1": self._labels[0][1],
|
||||
"labels/actual_label_mean/class2": self._labels[0][2],
|
||||
@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
"accuracy/baseline_label_mean": label_mean,
|
||||
"accuracy/threshold_0.500000_mean": 1. / 2,
|
||||
"auc": 1. / 2,
|
||||
"auc_precision_recall": 0.749999,
|
||||
"auc_precision_recall": 0.25,
|
||||
"labels/actual_label_mean": label_mean,
|
||||
"labels/prediction_mean": .731059, # softmax
|
||||
"loss": expected_loss,
|
||||
|
@ -1802,9 +1802,9 @@ class StreamingAUCTest(test.TestCase):
|
||||
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3)
|
||||
|
||||
def testAnotherAUCPRSpecialCase(self):
|
||||
with self.test_session() as sess:
|
||||
@ -1816,9 +1816,9 @@ class StreamingAUCTest(test.TestCase):
|
||||
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
|
||||
|
||||
def testThirdAUCPRSpecialCase(self):
|
||||
with self.test_session() as sess:
|
||||
@ -1830,9 +1830,9 @@ class StreamingAUCTest(test.TestCase):
|
||||
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
|
||||
|
||||
def testAllIncorrect(self):
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
@ -1865,9 +1865,9 @@ class StreamingAUCTest(test.TestCase):
|
||||
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1, sess.run(update_op), 6)
|
||||
self.assertAlmostEqual(0.49999976, sess.run(update_op), 6)
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
self.assertAlmostEqual(0.49999976, auc.eval(), 6)
|
||||
|
||||
def testWithMultipleUpdates(self):
|
||||
num_samples = 1000
|
||||
@ -6689,7 +6689,8 @@ class CohenKappaTest(test.TestCase):
|
||||
# [[0, 25, 0],
|
||||
# [0, 0, 25],
|
||||
# [25, 0, 0]]
|
||||
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
|
||||
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
|
||||
# labels, predictions)
|
||||
expect = -0.333333333333
|
||||
|
||||
with self.test_session() as sess:
|
||||
@ -6748,7 +6749,8 @@ class CohenKappaTest(test.TestCase):
|
||||
weights_t: weights[batch_start:batch_end]
|
||||
})
|
||||
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
|
||||
# labels_np, predictions_np, sample_weight=weights_np)
|
||||
# labels_np, predictions_np,
|
||||
# sample_weight=weights_np)
|
||||
expect = 0.289965397924
|
||||
self.assertAlmostEqual(expect, kappa.eval(), 5)
|
||||
|
||||
|
@ -1075,7 +1075,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
|
||||
metric_keys.MetricKeys.LABEL_MEAN: 1.,
|
||||
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
|
||||
metric_keys.MetricKeys.AUC: 0.,
|
||||
metric_keys.MetricKeys.AUC_PR: 1.,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.5,
|
||||
}
|
||||
else:
|
||||
# Multi classes: loss = 1 * -log ( softmax(logits)[label] )
|
||||
@ -1136,7 +1136,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
|
||||
metric_keys.MetricKeys.LABEL_MEAN: 0.5,
|
||||
metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
|
||||
metric_keys.MetricKeys.AUC: 0.5,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.75,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.25,
|
||||
}
|
||||
else:
|
||||
# Expand logits since batch_size=2
|
||||
@ -1212,7 +1212,7 @@ class BaselineClassifierEvaluationTest(test.TestCase):
|
||||
metric_keys.MetricKeys.ACCURACY_BASELINE: (
|
||||
max(label_mean, 1-label_mean)),
|
||||
metric_keys.MetricKeys.AUC: 0.5,
|
||||
metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),
|
||||
metric_keys.MetricKeys.AUC_PR: 0.16666645,
|
||||
}
|
||||
else:
|
||||
# Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )
|
||||
|
@ -1041,7 +1041,7 @@ class BaseDNNClassifierEvaluateTest(object):
|
||||
# There is no good way to calculate AUC for only two data points. But
|
||||
# that is what the algorithm returns.
|
||||
metric_keys.MetricKeys.AUC: 0.5,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.75,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.25,
|
||||
ops.GraphKeys.GLOBAL_STEP: global_step
|
||||
}, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))
|
||||
|
||||
|
@ -1558,7 +1558,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
keys.LABEL_MEAN: 2./2,
|
||||
keys.ACCURACY_BASELINE: 2./2,
|
||||
keys.AUC: 0.,
|
||||
keys.AUC_PR: 1.,
|
||||
keys.AUC_PR: 0.74999905,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
@ -1636,7 +1636,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
keys.LABEL_MEAN: 2./2,
|
||||
keys.ACCURACY_BASELINE: 2./2,
|
||||
keys.AUC: 0.,
|
||||
keys.AUC_PR: 1.,
|
||||
keys.AUC_PR: 0.75,
|
||||
}
|
||||
|
||||
# Assert predictions, loss, and metrics.
|
||||
@ -1741,7 +1741,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
keys.LABEL_MEAN: 2./2,
|
||||
keys.ACCURACY_BASELINE: 2./2,
|
||||
keys.AUC: 0.,
|
||||
keys.AUC_PR: 1.,
|
||||
keys.AUC_PR: 0.74999905,
|
||||
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1.,
|
||||
keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1.,
|
||||
keys.RECALL_AT_THRESHOLD % thresholds[0]: 1.,
|
||||
@ -2188,7 +2188,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
keys.LABEL_MEAN: expected_label_mean,
|
||||
keys.ACCURACY_BASELINE: 1 - expected_label_mean,
|
||||
keys.AUC: .45454565,
|
||||
keys.AUC_PR: .6737757325172424,
|
||||
keys.AUC_PR: .21923049,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
@ -2487,7 +2487,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
# We cannot reliably calculate AUC with only 4 data points, but the
|
||||
# values should not change because of backwards-compatibility.
|
||||
keys.AUC: 0.5222,
|
||||
keys.AUC_PR: 0.7341,
|
||||
keys.AUC_PR: 0.5119,
|
||||
}
|
||||
|
||||
tol = 1e-2
|
||||
|
@ -1342,7 +1342,7 @@ class BaseLinearClassifierEvaluationTest(object):
|
||||
metric_keys.MetricKeys.LABEL_MEAN: 1.,
|
||||
metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
|
||||
metric_keys.MetricKeys.AUC: 0.,
|
||||
metric_keys.MetricKeys.AUC_PR: 1.,
|
||||
metric_keys.MetricKeys.AUC_PR: 0.5,
|
||||
}
|
||||
else:
|
||||
# Multi classes: loss = 1 * -log ( soft_max(logits)[label] )
|
||||
|
@ -1105,9 +1105,9 @@ class AUCTest(test.TestCase):
|
||||
auc, update_op = metrics.auc(labels, predictions, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3)
|
||||
|
||||
def testAnotherAUCPRSpecialCase(self):
|
||||
with self.test_session() as sess:
|
||||
@ -1119,9 +1119,9 @@ class AUCTest(test.TestCase):
|
||||
auc, update_op = metrics.auc(labels, predictions, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
|
||||
|
||||
def testThirdAUCPRSpecialCase(self):
|
||||
with self.test_session() as sess:
|
||||
@ -1133,9 +1133,26 @@ class AUCTest(test.TestCase):
|
||||
auc, update_op = metrics.auc(labels, predictions, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
|
||||
|
||||
def testFourthAUCPRSpecialCase(self):
|
||||
# Create the labels and data.
|
||||
labels = np.array([
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 0, 1])
|
||||
predictions = np.array([
|
||||
0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35])
|
||||
|
||||
with self.test_session() as sess:
|
||||
auc, _ = metrics.auc(
|
||||
labels, predictions, curve='PR', num_thresholds=11)
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
# Since this is only approximate, we can't expect a 6 digits match.
|
||||
# Although with higher number of samples/thresholds we should see the
|
||||
# accuracy improving
|
||||
self.assertAlmostEqual(0.0, auc.eval(), delta=0.001)
|
||||
|
||||
def testAllIncorrect(self):
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
@ -1161,16 +1178,16 @@ class AUCTest(test.TestCase):
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
|
||||
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
|
||||
def testRecallOneAndPrecisionOne(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
|
||||
labels = array_ops.ones([4])
|
||||
auc, update_op = metrics.auc(labels, predictions, curve='PR')
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1, sess.run(update_op), 6)
|
||||
self.assertAlmostEqual(0.5, sess.run(update_op), 6)
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
self.assertAlmostEqual(0.5, auc.eval(), 6)
|
||||
|
||||
def np_auc(self, predictions, labels, weights):
|
||||
"""Computes the AUC explicitly using Numpy.
|
||||
|
@ -672,7 +672,7 @@ def auc(labels,
|
||||
x = fp_rate
|
||||
y = rec
|
||||
else: # curve == 'PR'.
|
||||
prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
|
||||
prec = math_ops.div(tp, tp + fp + epsilon)
|
||||
x = rec
|
||||
y = prec
|
||||
if summation_method == 'trapezoidal':
|
||||
@ -923,8 +923,8 @@ def mean_per_class_accuracy(labels,
|
||||
weights = array_ops.reshape(weights, [-1])
|
||||
weights = math_ops.to_float(weights)
|
||||
|
||||
is_correct = is_correct * weights
|
||||
ones = ones * weights
|
||||
is_correct *= weights
|
||||
ones *= weights
|
||||
|
||||
update_total_op = state_ops.scatter_add(total, labels, ones)
|
||||
update_count_op = state_ops.scatter_add(count, labels, is_correct)
|
||||
|
Loading…
Reference in New Issue
Block a user