Add streaming_precision_recall_at_equal_thresholds

This helper method computes streaming tp, fp, tn, fn, precision, and recall in a way that exhibits O(T + N) time and space complexity (instead of O(T * N)), where T is the number of thresholds and N is the size of the predictions tensor.

Thanks to Frank Chu for the efficient algorithm!

PiperOrigin-RevId: 172946073
A. Unique TensorFlower 2017-10-20 15:55:17 -07:00 committed by TensorFlower Gardener
parent ccfd9c1e50
commit ebcae4a5e3
2 changed files with 344 additions and 0 deletions
tensorflow/contrib/metrics/python/ops


@@ -22,6 +22,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as collections_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -1076,6 +1078,9 @@ def streaming_curve_points(labels=None,
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
TODO(chizeng): Consider rewriting this method to make use of logic within the
streaming_precision_recall_at_equal_thresholds method (to improve run time).
"""
with variable_scope.variable_scope(name, 'curve_points',
(labels, predictions, weights)):
@@ -1193,6 +1198,181 @@ def streaming_auc(predictions,
name=name)
def streaming_precision_recall_at_equal_thresholds(predictions,
labels,
num_thresholds=None,
weights=None,
name=None,
use_locking=None):
"""A helper method for creating metrics related to precision-recall curves.
These values are true positives (tp), false negatives (fn), true negatives
(tn), false positives (fp), precision, and recall. This function returns a
data structure that contains ops within it.
Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
space and run time), this op exhibits O(T + N) space and run time, where T is
the number of thresholds and N is the size of the predictions tensor. Hence,
it may be advantageous to use this function when `predictions` is big.
For instance, prefer this method for per-pixel classification tasks, for which
the predictions tensor may be very large.
Each number in `predictions`, a float in `[0, 1]`, is compared with its
corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
each threshold. This count is then multiplied by `weights`, which can be used
to reweight certain values or, more commonly, to mask values.
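For example, here is a minimal usage sketch (illustrative only; it assumes
`tf` is the imported TensorFlow module and that local variables are
initialized before the update op runs):

  predictions = tf.constant([0.2, 0.6, 1.0], dtype=tf.float32)
  labels = tf.constant([False, True, True], dtype=tf.bool)
  result, update_op = streaming_precision_recall_at_equal_thresholds(
      predictions=predictions, labels=labels, num_thresholds=3)
  # After update_op runs, result.thresholds is [0.0, 0.5, 1.0] and each of
  # result.tp, result.fp, result.tn, result.fn, result.precision, and
  # result.recall has shape [3].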
Args:
predictions: A floating point `Tensor` of arbitrary shape and whose values
are in the range `[0, 1]`.
labels: A bool `Tensor` whose shape matches `predictions`.
num_thresholds: Optional; Number of thresholds, evenly distributed in
`[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
is 1 less than `num_thresholds`. Using an even `num_thresholds` value
instead of an odd one may yield bin boundaries at non-round values (the
default of 201 gives a bin width of exactly 0.005).
weights: Optional; If provided, a `Tensor` that has the same dtype as,
and broadcastable to, `predictions`. This tensor is multiplied by counts.
name: Optional; variable_scope name. If not provided, the string
'precision_recall_at_equal_thresholds' is used.
use_locking: Optional; If True, the op will be protected by a lock.
Otherwise, the behavior is undefined, but may exhibit less contention.
Defaults to True.
Returns:
result: A named tuple (See PrecisionRecallData within the implementation of
this function) with properties that are variables of shape
`[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
precision, recall, thresholds.
update_op: An op that accumulates values.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`weights` is not `None` and its shape doesn't match `predictions`.
"""
# Disable the invalid-name checker so that we can capitalize the name.
# pylint: disable=invalid-name
PrecisionRecallData = collections_lib.namedtuple(
'PrecisionRecallData',
['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
# pylint: enable=invalid-name
if num_thresholds is None:
num_thresholds = 201
if weights is None:
weights = 1.0
if use_locking is None:
use_locking = True
check_ops.assert_type(labels, dtypes.bool)
dtype = predictions.dtype
with variable_scope.variable_scope(name,
'precision_recall_at_equal_thresholds',
(labels, predictions, weights)):
# Make sure that predictions are within [0.0, 1.0].
with ops.control_dependencies([
check_ops.assert_greater_equal(
predictions,
math_ops.cast(0.0, dtype=predictions.dtype),
message='predictions must be in [0, 1]'),
check_ops.assert_less_equal(
predictions,
math_ops.cast(1.0, dtype=predictions.dtype),
message='predictions must be in [0, 1]')
]):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=predictions, labels=labels, weights=weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
# We cast to float to ensure we have 0.0 or 1.0.
f_labels = math_ops.cast(labels, dtype)
# Get weighted true/false labels.
true_labels = f_labels * weights
false_labels = (1.0 - f_labels) * weights
# Flatten predictions and labels.
predictions = array_ops.reshape(predictions, [-1])
true_labels = array_ops.reshape(true_labels, [-1])
false_labels = array_ops.reshape(false_labels, [-1])
# To compute TP/FP/TN/FN, we are measuring a binary classifier
# C(t) = (predictions >= t)
# at each threshold 't'. So we have
# TP(t) = sum( C(t) * true_labels )
# FP(t) = sum( C(t) * false_labels )
#
# But, computing C(t) requires computation for each t. To make it fast,
# observe that C(t) is a cumulative integral, and so if we have
# thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
# where n = num_thresholds, and if we can compute the bucket function
# B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
# then we get
# C(t_i) = sum( B(j), j >= i )
# which is the reversed cumulative sum in tf.cumsum().
#
# We can compute B(i) efficiently by taking advantage of the fact that
# our thresholds are evenly distributed, in that
# width = 1.0 / (num_thresholds - 1)
# thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
# Given a prediction value p, we can map it to its bucket by
# bucket_index(p) = floor( p * (num_thresholds - 1) )
# so we can use tf.scatter_add() to update the buckets in one pass.
#
# This implementation exhibits a run time and space complexity of O(T + N),
# where T is the number of thresholds and N is the size of predictions.
# Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
# exhibit a complexity of O(T * N).
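#
# As a worked illustration (numbers are ours, added for clarity): with
# num_thresholds = 3 we get width = 0.5 and thresholds = [0.0, 0.5, 1.0].
# A prediction p = 0.6 maps to bucket floor(0.6 * 2) = 1, so it is counted
# by C(0.5) = B(1) + B(2) but not by C(1.0) = B(2).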
# Compute the bucket indices for each prediction value.
bucket_indices = math_ops.cast(
math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
with ops.name_scope('variables'):
tp_buckets_v = _create_local(
'tp_buckets', shape=[num_thresholds], dtype=dtype)
fp_buckets_v = _create_local(
'fp_buckets', shape=[num_thresholds], dtype=dtype)
with ops.name_scope('update_op'):
update_tp = state_ops.scatter_add(
tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
update_fp = state_ops.scatter_add(
fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)
# Set up the cumulative sums to compute the actual metrics.
tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
# fn = sum(true_labels) - tp
# = sum(tp_buckets) - tp
# = tp[0] - tp
# Similarly,
# tn = fp[0] - fp
tn = fp[0] - fp
fn = tp[0] - tp
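# As a quick check with illustrative numbers: if tp_buckets_v held
# [1, 2, 0], then tp = [3, 2, 0] and fn = tp[0] - tp = [0, 1, 3].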
# Clamp the denominator to a minimum of epsilon to prevent division by 0.
epsilon = 1e-7
precision = tp / math_ops.maximum(epsilon, tp + fp)
recall = tp / math_ops.maximum(epsilon, tp + fn)
result = PrecisionRecallData(
tp=tp,
fp=fp,
tn=tn,
fn=fn,
precision=precision,
recall=recall,
thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
update_op = control_flow_ops.group(update_tp, update_fp)
return result, update_op
def streaming_specificity_at_sensitivity(predictions,
labels,
sensitivity,


@@ -1970,6 +1970,170 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def setUp(self):
np.random.seed(1)
ops.reset_default_graph()
def _testResultsEqual(self, expected_dict, gotten_result):
"""Tests that 2 results (dicts) represent the same data.
Args:
expected_dict: A dictionary with keys that are the names of properties
of PrecisionRecallData and whose values are lists of floats.
gotten_result: A PrecisionRecallData object.
"""
gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
self.assertItemsEqual(
list(expected_dict.keys()), list(gotten_dict.keys()))
for key, expected_values in expected_dict.items():
self.assertAllClose(expected_values, gotten_dict[key])
def _testCase(self, predictions, labels, expected_result, weights=None):
"""Performs a test given a certain scenario of labels, predictions, weights.
Args:
predictions: The predictions tensor. Of type float32.
labels: The labels tensor. Of type bool.
expected_result: The expected result (dict) that maps to tensors.
weights: Optional weights tensor.
"""
with self.test_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
weights_tensor = None
if weights:
weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32)
gotten_result, update_op = (
metric_ops.streaming_precision_recall_at_equal_thresholds(
predictions=predictions_tensor,
labels=labels_tensor,
num_thresholds=3,
weights=weights_tensor))
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self._testResultsEqual(expected_result, gotten_result)
def testVars(self):
metric_ops.streaming_precision_recall_at_equal_thresholds(
predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
labels=constant_op.constant([True], dtype=dtypes_lib.bool))
_assert_local_variables(
self,
(
'precision_recall_at_equal_thresholds/variables/tp_buckets:0',
'precision_recall_at_equal_thresholds/variables/fp_buckets:0'
))
def testVarsWithName(self):
metric_ops.streaming_precision_recall_at_equal_thresholds(
predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
labels=constant_op.constant([True], dtype=dtypes_lib.bool),
name='foo')
_assert_local_variables(
self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0'))
def testValuesAreIdempotent(self):
predictions = constant_op.constant(
np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
labels = constant_op.constant(
np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool)
result, update_op = (
metric_ops.streaming_precision_recall_at_equal_thresholds(
predictions=predictions, labels=labels))
with self.test_session() as sess:
# Run several updates.
sess.run(variables.local_variables_initializer())
for _ in range(3):
sess.run(update_op)
# Then verify idempotency.
initial_result = {k: value.eval().tolist() for k, value in
result._asdict().items()}
for _ in range(3):
self._testResultsEqual(initial_result, result)
def testAllTruePositives(self):
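# Note (added for clarity): a prediction of 1.0 maps to the top bucket,
# so it counts as positive at every threshold, including 1.0 itself.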
self._testCase([[1]], [[True]], {
'tp': [1, 1, 1],
'fp': [0, 0, 0],
'tn': [0, 0, 0],
'fn': [0, 0, 0],
'precision': [1.0, 1.0, 1.0],
'recall': [1.0, 1.0, 1.0],
'thresholds': [0.0, 0.5, 1.0],
})
def testAllTrueNegatives(self):
self._testCase([[0]], [[False]], {
'tp': [0, 0, 0],
'fp': [1, 0, 0],
'tn': [0, 1, 1],
'fn': [0, 0, 0],
'precision': [0.0, 0.0, 0.0],
'recall': [0.0, 0.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
})
def testAllFalsePositives(self):
self._testCase([[1]], [[False]], {
'tp': [0, 0, 0],
'fp': [1, 1, 1],
'tn': [0, 0, 0],
'fn': [0, 0, 0],
'precision': [0.0, 0.0, 0.0],
'recall': [0.0, 0.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
})
def testAllFalseNegatives(self):
self._testCase([[0]], [[True]], {
'tp': [1, 0, 0],
'fp': [0, 0, 0],
'tn': [0, 0, 0],
'fn': [0, 1, 1],
'precision': [1.0, 0.0, 0.0],
'recall': [1.0, 0.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
})
def testManyValues(self):
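# Derivation of the expected values (added for clarity): at threshold 0.0
# every prediction is positive, so tp = 4 (the four True labels) and
# fp = 2. At threshold 0.5 only {0.6, 0.7, 0.8} are positive, all labeled
# True, so tp = 3, fp = 0, fn = 1. At threshold 1.0 no prediction reaches
# the top bucket, so tp = fp = 0.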
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
[[True, False, False, True, True, True]],
{
'tp': [4, 3, 0],
'fp': [2, 0, 0],
'tn': [0, 2, 2],
'fn': [0, 1, 4],
'precision': [2.0 / 3.0, 1.0, 0.0],
'recall': [1.0, 0.75, 0.0],
'thresholds': [0.0, 0.5, 1.0],
})
def testManyValuesWithWeights(self):
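# Derivation of the expected values (added for clarity): the weighted
# positive label mass is 0.0 + 0.0 + 0.5 + 1.0 = 1.5 and the weighted
# negative mass is 0.5 + 2.0 = 2.5. At threshold 0.0, tp = 1.5 and
# fp = 2.5, so precision = 1.5 / 4.0 = 0.375. At threshold 0.5 the positive
# predictions {0.6, 0.7, 0.8} carry weights {0.0, 0.5, 1.0}, giving
# tp = 1.5, fp = 0.0, and recall = 1.0.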
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
[[True, False, False, True, True, True]],
{
'tp': [1.5, 1.5, 0.0],
'fp': [2.5, 0.0, 0.0],
'tn': [0.0, 2.5, 2.5],
'fn': [0.0, 0.0, 1.5],
'precision': [0.375, 1.0, 0.0],
'recall': [1.0, 1.0, 0.0],
'thresholds': [0.0, 0.5, 1.0],
},
weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])
class StreamingSpecificityAtSensitivityTest(test.TestCase):
def setUp(self):