Adds streaming_dynamic_auc to TensorFlow contrib metrics. This metric differs from streaming_auc because it uses every prediction as a threshold rather than linearly spaced fixed thresholds.
PiperOrigin-RevId: 175217002
This commit is contained in:
parent
f3f85e9aa0
commit
b11a790328
@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide.
|
||||
@@streaming_false_negative_rate
|
||||
@@streaming_false_negative_rate_at_thresholds
|
||||
@@streaming_auc
|
||||
@@streaming_dynamic_auc
|
||||
@@streaming_curve_points
|
||||
@@streaming_recall_at_k
|
||||
@@streaming_mean_absolute_error
|
||||
@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
|
||||
|
@ -1178,6 +1178,154 @@ def streaming_auc(predictions,
|
||||
name=name)
|
||||
|
||||
|
||||
def _compute_dynamic_auc(labels, predictions, curve='ROC'):
|
||||
"""Computes the approximate AUC by a Riemann sum with data-derived thresholds.
|
||||
|
||||
Computes the area under the ROC or PR curve using each prediction as a
|
||||
threshold. This could be slow for large batches, but has the advantage of not
|
||||
having its results degrade depending on the distribution of predictions.
|
||||
|
||||
Args:
|
||||
labels: A `Tensor` of ground truth labels with the same shape as
|
||||
`predictions` with values of 0 or 1 and type `int64`.
|
||||
predictions: A 1-D `Tensor` of predictions whose values are `float64`.
|
||||
curve: The name of the curve to be computed, 'ROC' for the Receiver
|
||||
Operating Characteristic or 'PR' for the Precision-Recall curve.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` containing the area-under-curve value for the input.
|
||||
"""
|
||||
# Count the total number of positive and negative labels in the input.
|
||||
size = array_ops.size(predictions)
|
||||
total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
|
||||
|
||||
def continue_computing_dynamic_auc():
|
||||
"""Continues dynamic auc computation, entered if labels are not all equal.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` containing the area-under-curve value.
|
||||
"""
|
||||
# Sort the predictions descending, and the corresponding labels as well.
|
||||
ordered_predictions, indices = nn.top_k(predictions, k=size)
|
||||
ordered_labels = array_ops.gather(labels, indices)
|
||||
|
||||
# Get the counts of the unique ordered predictions.
|
||||
_, _, counts = array_ops.unique_with_counts(ordered_predictions)
|
||||
|
||||
# Compute the indices of the split points between different predictions.
|
||||
splits = math_ops.cast(
|
||||
array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
|
||||
|
||||
# Count the positives to the left of the split indices.
|
||||
positives = math_ops.cast(
|
||||
array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
|
||||
dtypes.int32)
|
||||
true_positives = array_ops.gather(positives, splits)
|
||||
if curve == 'ROC':
|
||||
# Count the negatives to the left of every split point and the total
|
||||
# number of negatives for computing the FPR.
|
||||
false_positives = math_ops.subtract(splits, true_positives)
|
||||
total_negative = size - total_positive
|
||||
x_axis_values = math_ops.truediv(false_positives, total_negative)
|
||||
y_axis_values = math_ops.truediv(true_positives, total_positive)
|
||||
elif curve == 'PR':
|
||||
x_axis_values = math_ops.truediv(true_positives, total_positive)
|
||||
# For conformance, set precision to 1 when the number of positive
|
||||
# classifications is 0.
|
||||
y_axis_values = array_ops.where(
|
||||
math_ops.greater(splits, 0),
|
||||
math_ops.truediv(true_positives, splits),
|
||||
array_ops.ones_like(true_positives, dtype=dtypes.float64))
|
||||
|
||||
# Calculate trapezoid areas.
|
||||
heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
|
||||
widths = math_ops.abs(
|
||||
math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
|
||||
return math_ops.reduce_sum(math_ops.multiply(heights, widths))
|
||||
|
||||
# If all the labels are the same, AUC isn't well-defined (but raising an
|
||||
# exception seems excessive) so we return 0, otherwise we finish computing.
|
||||
return control_flow_ops.cond(
|
||||
math_ops.logical_or(
|
||||
math_ops.equal(total_positive, 0),
|
||||
math_ops.equal(total_positive, size)
|
||||
),
|
||||
true_fn=lambda: array_ops.constant(0, dtypes.float64),
|
||||
false_fn=continue_computing_dynamic_auc)
|
||||
|
||||
|
||||
def streaming_dynamic_auc(labels,
|
||||
predictions,
|
||||
curve='ROC',
|
||||
metrics_collections=(),
|
||||
updates_collections=(),
|
||||
name=None):
|
||||
"""Computes the approximate AUC by a Riemann sum with data-derived thresholds.
|
||||
|
||||
USAGE NOTE: this approach requires storing all of the predictions and labels
|
||||
for a single evaluation in memory, so it may not be usable when the evaluation
|
||||
batch size and/or the number of evaluation steps is very large.
|
||||
|
||||
Computes the area under the ROC or PR curve using each prediction as a
|
||||
threshold. This has the advantage of being resilient to the distribution of
|
||||
predictions by aggregating across batches, accumulating labels and predictions
|
||||
and performing the final calculation using all of the concatenated values.
|
||||
|
||||
Args:
|
||||
labels: A `Tensor` of ground truth labels with the same shape as `predictions`
|
||||
and with values of 0 or 1 whose values are castable to `int64`.
|
||||
predictions: A `Tensor` of predictions whose values are castable to
|
||||
`float64`. Will be flattened into a 1-D `Tensor`.
|
||||
curve: The name of the curve for which to compute AUC, 'ROC' for the
|
||||
Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
|
||||
metrics_collections: An optional iterable of collections that `auc` should
|
||||
be added to.
|
||||
updates_collections: An optional iterable of collections that `update_op`
|
||||
should be added to.
|
||||
name: An optional name for the variable_scope that contains the metric
|
||||
variables.
|
||||
|
||||
Returns:
|
||||
auc: A scalar `Tensor` containing the current area-under-curve value.
|
||||
update_op: An operation that concatenates the input labels and predictions
|
||||
to the accumulated values.
|
||||
|
||||
Raises:
|
||||
ValueError: If `labels` and `predictions` have mismatched shapes or if
|
||||
`curve` isn't a recognized curve type.
|
||||
"""
|
||||
|
||||
if curve not in ['PR', 'ROC']:
|
||||
raise ValueError('curve must be either ROC or PR, %s unknown' % curve)
|
||||
|
||||
with variable_scope.variable_scope(name, default_name='dynamic_auc'):
|
||||
labels.get_shape().assert_is_compatible_with(predictions.get_shape())
|
||||
predictions = array_ops.reshape(
|
||||
math_ops.cast(predictions, dtypes.float64), [-1])
|
||||
labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_greater_equal(
|
||||
labels,
|
||||
array_ops.zeros_like(labels, dtypes.int64),
|
||||
message='labels must be 0 or 1, at least one is <0'),
|
||||
check_ops.assert_less_equal(
|
||||
labels,
|
||||
array_ops.ones_like(labels, dtypes.int64),
|
||||
message='labels must be 0 or 1, at least one is >1')
|
||||
]):
|
||||
preds_accum, update_preds = streaming_concat(predictions,
|
||||
name='concat_preds')
|
||||
labels_accum, update_labels = streaming_concat(labels,
|
||||
name='concat_labels')
|
||||
update_op = control_flow_ops.group(update_labels, update_preds)
|
||||
auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, auc)
|
||||
return auc, update_op
|
||||
|
||||
|
||||
def streaming_precision_recall_at_equal_thresholds(predictions,
|
||||
labels,
|
||||
num_thresholds=None,
|
||||
@ -3285,6 +3433,7 @@ __all__ = [
|
||||
'streaming_accuracy',
|
||||
'streaming_auc',
|
||||
'streaming_curve_points',
|
||||
'streaming_dynamic_auc',
|
||||
'streaming_false_negative_rate',
|
||||
'streaming_false_negative_rate_at_thresholds',
|
||||
'streaming_false_negatives',
|
||||
|
@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase):
|
||||
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
|
||||
|
||||
|
||||
def _np_auc(predictions, labels, weights=None):
|
||||
"""Computes the AUC explicitly using Numpy.
|
||||
|
||||
Args:
|
||||
predictions: an ndarray with shape [N].
|
||||
labels: an ndarray with shape [N].
|
||||
weights: an optional ndarray with shape [N]; defaults to all ones when None.
|
||||
|
||||
Returns:
|
||||
the area under the ROC curve.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = np.ones(np.size(predictions))
|
||||
is_positive = labels > 0
|
||||
num_positives = np.sum(weights[is_positive])
|
||||
num_negatives = np.sum(weights[~is_positive])
|
||||
|
||||
# Sort descending:
|
||||
inds = np.argsort(-predictions)
|
||||
|
||||
sorted_labels = labels[inds]
|
||||
sorted_weights = weights[inds]
|
||||
is_positive = sorted_labels > 0
|
||||
|
||||
tp = np.cumsum(sorted_weights * is_positive) / num_positives
|
||||
return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
|
||||
|
||||
|
||||
class StreamingAUCTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase):
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
|
||||
def np_auc(self, predictions, labels, weights):
|
||||
"""Computes the AUC explicitly using Numpy.
|
||||
|
||||
Args:
|
||||
predictions: an ndarray with shape [N].
|
||||
labels: an ndarray with shape [N].
|
||||
weights: an ndarray with shape [N].
|
||||
|
||||
Returns:
|
||||
the area under the ROC curve.
|
||||
"""
|
||||
if weights is None:
|
||||
weights = np.ones(np.size(predictions))
|
||||
is_positive = labels > 0
|
||||
num_positives = np.sum(weights[is_positive])
|
||||
num_negatives = np.sum(weights[~is_positive])
|
||||
|
||||
# Sort descending:
|
||||
inds = np.argsort(-predictions)
|
||||
|
||||
sorted_labels = labels[inds]
|
||||
sorted_weights = weights[inds]
|
||||
is_positive = sorted_labels > 0
|
||||
|
||||
tp = np.cumsum(sorted_weights * is_positive) / num_positives
|
||||
return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
|
||||
|
||||
def testWithMultipleUpdates(self):
|
||||
num_samples = 1000
|
||||
batch_size = 10
|
||||
@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase):
|
||||
|
||||
for weights in (None, np.ones(num_samples), np.random.exponential(
|
||||
scale=1.0, size=num_samples)):
|
||||
expected_auc = self.np_auc(predictions, labels, weights)
|
||||
expected_auc = _np_auc(predictions, labels, weights)
|
||||
|
||||
with self.test_session() as sess:
|
||||
enqueue_ops = [[] for i in range(num_batches)]
|
||||
@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase):
|
||||
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
|
||||
|
||||
|
||||
class StreamingDynamicAUCTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(StreamingDynamicAUCTest, self).setUp()
|
||||
np.random.seed(1)
|
||||
ops.reset_default_graph()
|
||||
|
||||
def testUnknownCurve(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
|
||||
metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
|
||||
predictions=array_ops.ones((10, 1)),
|
||||
curve='TEST_CURVE')
|
||||
|
||||
def testVars(self):
|
||||
metrics.streaming_dynamic_auc(
|
||||
labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
|
||||
_assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
|
||||
'dynamic_auc/concat_labels/size:0',
|
||||
'dynamic_auc/concat_preds/array:0',
|
||||
'dynamic_auc/concat_preds/size:0'])
|
||||
|
||||
def testMetricsCollection(self):
|
||||
my_collection_name = '__metrics__'
|
||||
auc, _ = metrics.streaming_dynamic_auc(
|
||||
labels=array_ops.ones((10, 1)),
|
||||
predictions=array_ops.ones((10, 1)),
|
||||
metrics_collections=[my_collection_name])
|
||||
self.assertEqual(ops.get_collection(my_collection_name), [auc])
|
||||
|
||||
def testUpdatesCollection(self):
|
||||
my_collection_name = '__updates__'
|
||||
_, update_op = metrics.streaming_dynamic_auc(
|
||||
labels=array_ops.ones((10, 1)),
|
||||
predictions=array_ops.ones((10, 1)),
|
||||
updates_collections=[my_collection_name])
|
||||
self.assertEqual(ops.get_collection(my_collection_name), [update_op])
|
||||
|
||||
def testValueTensorIsIdempotent(self):
|
||||
predictions = random_ops.random_uniform(
|
||||
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
|
||||
labels = random_ops.random_uniform(
|
||||
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.local_variables_initializer())
|
||||
# Run several updates.
|
||||
for _ in xrange(10):
|
||||
sess.run(update_op)
|
||||
# Then verify idempotency.
|
||||
initial_auc = auc.eval()
|
||||
for _ in xrange(10):
|
||||
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
|
||||
|
||||
def testAllLabelsOnes(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([1., 1., 1.])
|
||||
labels = constant_op.constant([1, 1, 1])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertEqual(0, auc.eval())
|
||||
|
||||
def testAllLabelsZeros(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([1., 1., 1.])
|
||||
labels = constant_op.constant([0, 0, 0])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertEqual(0, auc.eval())
|
||||
|
||||
def testNonZeroOnePredictions(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
|
||||
dtype=dtypes_lib.float32)
|
||||
labels = constant_op.constant([1, 0, 1, 0])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(auc.eval(), 1.0)
|
||||
|
||||
def testAllCorrect(self):
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(inputs)
|
||||
labels = constant_op.constant(inputs)
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertEqual(1, auc.eval())
|
||||
|
||||
def testSomeCorrect(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([1, 0, 1, 0])
|
||||
labels = constant_op.constant([0, 1, 1, 0])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(0.5, auc.eval())
|
||||
|
||||
def testAllIncorrect(self):
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
|
||||
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
|
||||
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(0, auc.eval())
|
||||
|
||||
def testExceptionOnIncompatibleShapes(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = array_ops.ones([5])
|
||||
labels = array_ops.zeros([6])
|
||||
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
|
||||
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
|
||||
def testExceptionOnGreaterThanOneLabel(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
|
||||
labels = constant_op.constant([2, 1, 0])
|
||||
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
with self.assertRaisesRegexp(
|
||||
errors_impl.InvalidArgumentError,
|
||||
'.*labels must be 0 or 1, at least one is >1.*'):
|
||||
sess.run(update_op)
|
||||
|
||||
def testExceptionOnNegativeLabel(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
|
||||
labels = constant_op.constant([1, 0, -1])
|
||||
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
with self.assertRaisesRegexp(
|
||||
errors_impl.InvalidArgumentError,
|
||||
'.*labels must be 0 or 1, at least one is <0.*'):
|
||||
sess.run(update_op)
|
||||
|
||||
def testWithMultipleUpdates(self):
|
||||
batch_size = 10
|
||||
num_batches = 100
|
||||
labels = np.array([])
|
||||
predictions = np.array([])
|
||||
tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||
dtype=dtypes_lib.int32)
|
||||
tf_predictions = variables.Variable(
|
||||
array_ops.ones(batch_size),
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||
dtype=dtypes_lib.float32)
|
||||
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.local_variables_initializer())
|
||||
for _ in xrange(num_batches):
|
||||
new_labels = np.random.randint(0, 2, size=batch_size)
|
||||
noise = np.random.normal(0.0, scale=0.2, size=batch_size)
|
||||
new_predictions = 0.4 + 0.2 * new_labels + noise
|
||||
labels = np.concatenate([labels, new_labels])
|
||||
predictions = np.concatenate([predictions, new_predictions])
|
||||
sess.run(tf_labels.assign(new_labels))
|
||||
sess.run(tf_predictions.assign(new_predictions))
|
||||
sess.run(update_op)
|
||||
expected_auc = _np_auc(predictions, labels)
|
||||
self.assertAlmostEqual(expected_auc, auc.eval())
|
||||
|
||||
def testAUCPRReverseIncreasingPredictions(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(
|
||||
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
|
||||
labels = constant_op.constant([0, 0, 1, 1])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(
|
||||
labels, predictions, curve='PR')
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
|
||||
|
||||
def testAUCPRJumbledPredictions(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(
|
||||
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
|
||||
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
|
||||
auc, update_op = metrics.streaming_dynamic_auc(
|
||||
labels, predictions, curve='PR')
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
|
||||
|
||||
def testAUCPRPredictionsLessThanHalf(self):
|
||||
with self.test_session() as sess:
|
||||
predictions = constant_op.constant(
|
||||
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
|
||||
shape=(1, 7),
|
||||
dtype=dtypes_lib.float32)
|
||||
labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
|
||||
auc, update_op = metrics.streaming_dynamic_auc(
|
||||
labels, predictions, curve='PR')
|
||||
sess.run(variables.local_variables_initializer())
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
|
||||
|
||||
|
||||
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
Reference in New Issue
Block a user