Add streaming_precision_recall_at_equal_thresholds
This helper method computes streaming tp, fp, tn, fn, precision, and recall for the user in a way that exhibits O(T + N) time and space complexity (instead of O(T * N)), where T is the number of thresholds and N is the size of the predictions tensor. Thanks to Frank Chu for the efficient algorithm!

PiperOrigin-RevId: 172946073
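For orientation, here is a minimal sketch of how the new helper might be called (TF 1.x session style; the toy constants and num_thresholds value are illustrative assumptions, not part of the commit):

import tensorflow as tf
from tensorflow.contrib.metrics.python.ops import metric_ops

# Toy inputs; predictions must lie in [0, 1] and labels must be boolean.
predictions = tf.constant([0.2, 0.6, 0.9], dtype=tf.float32)
labels = tf.constant([False, True, True], dtype=tf.bool)

result, update_op = metric_ops.streaming_precision_recall_at_equal_thresholds(
    predictions=predictions, labels=labels, num_thresholds=5)

with tf.Session() as sess:
  # The tp/fp buckets live in local variables, so initialize them first.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)  # Accumulate one batch of statistics.
  precision, recall = sess.run([result.precision, result.recall])

Each property of the returned named tuple (tp, fp, tn, fn, precision, recall, thresholds) is a tensor of shape [num_thresholds].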
parent ccfd9c1e50
commit ebcae4a5e3

tensorflow/contrib/metrics/python/ops
@@ -22,6 +22,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -1076,6 +1078,9 @@ def streaming_curve_points(labels=None,
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.

  TODO(chizeng): Consider rewriting this method to make use of logic within
  the streaming_precision_recall_at_equal_thresholds method (to improve run
  time).
  """
  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
@@ -1193,6 +1198,181 @@ def streaming_auc(predictions,
      name=name)


def streaming_precision_recall_at_equal_thresholds(predictions,
                                                   labels,
                                                   num_thresholds=None,
                                                   weights=None,
                                                   name=None,
                                                   use_locking=None):
  """A helper method for creating metrics related to precision-recall curves.

  These values are true positives, false negatives, true negatives, false
  positives, precision, and recall. This function returns a data structure
  that contains ops within it.

  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
  space and run time), this op exhibits O(T + N) space and run time, where T
  is the number of thresholds and N is the size of the predictions tensor.
  Hence, it may be advantageous to use this function when `predictions` is
  big.

  For instance, prefer this method for per-pixel classification tasks, for
  which the predictions tensor may be very large.

  Each number in `predictions`, a float in `[0, 1]`, is compared with its
  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value
  at each threshold. This is then multiplied with `weights` which can be
  used to reweight certain values, or more commonly used for masking values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose
      values are in the range `[0, 1]`.
    labels: A bool `Tensor` whose shape matches `predictions`.
    num_thresholds: Optional; Number of thresholds, evenly distributed in
      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of
      bins is 1 less than `num_thresholds`. Using an even `num_thresholds`
      value instead of an odd one may yield unfriendly edges for bins.
    weights: Optional; If provided, a `Tensor` that has the same dtype as,
      and broadcastable to, `predictions`. This tensor is multiplied by
      counts.
    name: Optional; variable_scope name. If not provided, the string
      'precision_recall_at_equal_thresholds' is used.
    use_locking: Optional; If True, the op will be protected by a lock.
      Otherwise, the behavior is undefined, but may exhibit less contention.
      Defaults to True.

  Returns:
    result: A named tuple (see PrecisionRecallData within the implementation
      of this function) with properties that are variables of shape
      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
      precision, recall, thresholds.
    update_op: An op that accumulates values.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  PrecisionRecallData = collections_lib.namedtuple(
      'PrecisionRecallData',
      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
  # pylint: enable=invalid-name

  if num_thresholds is None:
    num_thresholds = 201

  if weights is None:
    weights = 1.0

  if use_locking is None:
    use_locking = True

  check_ops.assert_type(labels, dtypes.bool)

  dtype = predictions.dtype
  with variable_scope.variable_scope(name,
                                     'precision_recall_at_equal_thresholds',
                                     (labels, predictions, weights)):
    # Make sure that predictions are within [0.0, 1.0].
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            predictions,
            math_ops.cast(0.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]'),
        check_ops.assert_less_equal(
            predictions,
            math_ops.cast(1.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]')
    ]):
      predictions, labels, weights = _remove_squeezable_dimensions(
          predictions=predictions, labels=labels, weights=weights)

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # We cast to float to ensure we have 0.0 or 1.0.
    f_labels = math_ops.cast(labels, dtype)

    # Get weighted true/false labels.
    true_labels = f_labels * weights
    false_labels = (1.0 - f_labels) * weights

    # Flatten predictions and labels.
    predictions = array_ops.reshape(predictions, [-1])
    true_labels = array_ops.reshape(true_labels, [-1])
    false_labels = array_ops.reshape(false_labels, [-1])

    # To compute TP/FP/TN/FN, we are measuring a binary classifier
    #   C(t) = (predictions >= t)
    # at each threshold 't'. So we have
    #   TP(t) = sum( C(t) * true_labels )
    #   FP(t) = sum( C(t) * false_labels )
    #
    # But, computing C(t) requires computation for each t. To make it fast,
    # observe that C(t) is a cumulative integral, and so if we have
    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    # where n = num_thresholds, and if we can compute the bucket function
    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    # then we get
    #   C(t_i) = sum( B(j), j >= i )
    # which is the reversed cumulative sum in tf.cumsum().
    #
    # We can compute B(i) efficiently by taking advantage of the fact that
    # our thresholds are evenly distributed, in that
    #   width = 1.0 / (num_thresholds - 1)
    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    # Given a prediction value p, we can map it to its bucket by
    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
    # so we can use tf.scatter_add() to update the buckets in one pass.
    #
    # This implementation exhibits a run time and space complexity of
    # O(T + N), where T is the number of thresholds and N is the size of
    # predictions. Metrics that rely on
    # _streaming_confusion_matrix_at_thresholds instead exhibit a complexity
    # of O(T * N).

    # Compute the bucket indices for each prediction value.
    bucket_indices = math_ops.cast(
        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)

    with ops.name_scope('variables'):
      tp_buckets_v = _create_local(
          'tp_buckets', shape=[num_thresholds], dtype=dtype)
      fp_buckets_v = _create_local(
          'fp_buckets', shape=[num_thresholds], dtype=dtype)

    with ops.name_scope('update_op'):
      update_tp = state_ops.scatter_add(
          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
      update_fp = state_ops.scatter_add(
          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)

    # Set up the cumulative sums to compute the actual metrics.
    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
    # fn = sum(true_labels) - tp
    #    = sum(tp_buckets) - tp
    #    = tp[0] - tp
    # Similarly,
    # tn = fp[0] - fp
    tn = fp[0] - fp
    fn = tp[0] - tp

    # We use a minimum to prevent division by 0.
    epsilon = 1e-7
    precision = tp / math_ops.maximum(epsilon, tp + fp)
    recall = tp / math_ops.maximum(epsilon, tp + fn)

    result = PrecisionRecallData(
        tp=tp,
        fp=fp,
        tn=tn,
        fn=fn,
        precision=precision,
        recall=recall,
        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
    update_op = control_flow_ops.group(update_tp, update_fp)
    return result, update_op


def streaming_specificity_at_sensitivity(predictions,
                                         labels,
                                         sensitivity,
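To make the bucketing argument in the comments above concrete, here is a small NumPy sketch (not part of the commit) that computes TP at each threshold both the naive O(T * N) way and via the bucket-plus-reversed-cumsum trick, and checks that the two agree:

import numpy as np

predictions = np.array([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
true_labels = np.array([1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
num_thresholds = 3
thresholds = np.linspace(0.0, 1.0, num_thresholds)  # [0.0, 0.5, 1.0]

# Naive: one pass over the data per threshold -> O(T * N).
tp_naive = np.array(
    [np.sum((predictions >= t) * true_labels) for t in thresholds])

# Bucketed: one pass over the data plus a reversed cumulative sum -> O(T + N).
bucket_indices = np.floor(predictions * (num_thresholds - 1)).astype(np.int64)
tp_buckets = np.zeros(num_thresholds)
np.add.at(tp_buckets, bucket_indices, true_labels)  # scatter_add analogue
tp_bucketed = np.cumsum(tp_buckets[::-1])[::-1]     # reversed cumsum

assert np.allclose(tp_naive, tp_bucketed)  # both yield [4., 3., 0.]

The same reversed cumulative sum over false_labels yields FP, after which fn = tp[0] - tp and tn = fp[0] - fp follow exactly as in the code above.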
@@ -1970,6 +1970,170 @@ class StreamingAUCTest(test.TestCase):
      self.assertAlmostEqual(expected_auc, auc.eval(), 2)


class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _testResultsEqual(self, expected_dict, gotten_result):
    """Tests that 2 results (dicts) represent the same data.

    Args:
      expected_dict: A dictionary with keys that are the names of properties
        of PrecisionRecallData and whose values are lists of floats.
      gotten_result: A PrecisionRecallData object.
    """
    gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
    self.assertItemsEqual(
        list(expected_dict.keys()), list(gotten_dict.keys()))

    for key, expected_values in expected_dict.items():
      self.assertAllClose(expected_values, gotten_dict[key])

  def _testCase(self, predictions, labels, expected_result, weights=None):
    """Performs a test given a certain scenario of labels, predictions,
    weights.

    Args:
      predictions: The predictions tensor. Of type float32.
      labels: The labels tensor. Of type bool.
      expected_result: The expected result (dict) that maps to tensors.
      weights: Optional weights tensor.
    """
    with self.test_session() as sess:
      predictions_tensor = constant_op.constant(
          predictions, dtype=dtypes_lib.float32)
      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
      weights_tensor = None
      if weights:
        weights_tensor = constant_op.constant(
            weights, dtype=dtypes_lib.float32)
      gotten_result, update_op = (
          metric_ops.streaming_precision_recall_at_equal_thresholds(
              predictions=predictions_tensor,
              labels=labels_tensor,
              num_thresholds=3,
              weights=weights_tensor))

      sess.run(variables.local_variables_initializer())
      sess.run(update_op)

      self._testResultsEqual(expected_result, gotten_result)

  def testVars(self):
    metric_ops.streaming_precision_recall_at_equal_thresholds(
        predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
        labels=constant_op.constant([True], dtype=dtypes_lib.bool))
    _assert_local_variables(
        self,
        (
            'precision_recall_at_equal_thresholds/variables/tp_buckets:0',
            'precision_recall_at_equal_thresholds/variables/fp_buckets:0'
        ))

  def testVarsWithName(self):
    metric_ops.streaming_precision_recall_at_equal_thresholds(
        predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
        labels=constant_op.constant([True], dtype=dtypes_lib.bool),
        name='foo')
    _assert_local_variables(
        self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0'))

  def testValuesAreIdempotent(self):
    predictions = constant_op.constant(
        np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool)

    result, update_op = (
        metric_ops.streaming_precision_recall_at_equal_thresholds(
            predictions=predictions, labels=labels))

    with self.test_session() as sess:
      # Run several updates.
      sess.run(variables.local_variables_initializer())
      for _ in range(3):
        sess.run(update_op)

      # Then verify idempotency.
      initial_result = {k: value.eval().tolist() for k, value in
                        result._asdict().items()}
      for _ in range(3):
        self._testResultsEqual(initial_result, result)

  def testAllTruePositives(self):
    self._testCase([[1]], [[True]], {
        'tp': [1, 1, 1],
        'fp': [0, 0, 0],
        'tn': [0, 0, 0],
        'fn': [0, 0, 0],
        'precision': [1.0, 1.0, 1.0],
        'recall': [1.0, 1.0, 1.0],
        'thresholds': [0.0, 0.5, 1.0],
    })

  def testAllTrueNegatives(self):
    self._testCase([[0]], [[False]], {
        'tp': [0, 0, 0],
        'fp': [1, 0, 0],
        'tn': [0, 1, 1],
        'fn': [0, 0, 0],
        'precision': [0.0, 0.0, 0.0],
        'recall': [0.0, 0.0, 0.0],
        'thresholds': [0.0, 0.5, 1.0],
    })

  def testAllFalsePositives(self):
    self._testCase([[1]], [[False]], {
        'tp': [0, 0, 0],
        'fp': [1, 1, 1],
        'tn': [0, 0, 0],
        'fn': [0, 0, 0],
        'precision': [0.0, 0.0, 0.0],
        'recall': [0.0, 0.0, 0.0],
        'thresholds': [0.0, 0.5, 1.0],
    })

  def testAllFalseNegatives(self):
    self._testCase([[0]], [[True]], {
        'tp': [1, 0, 0],
        'fp': [0, 0, 0],
        'tn': [0, 0, 0],
        'fn': [0, 1, 1],
        'precision': [1.0, 0.0, 0.0],
        'recall': [1.0, 0.0, 0.0],
        'thresholds': [0.0, 0.5, 1.0],
    })

  def testManyValues(self):
    self._testCase(
        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
        [[True, False, False, True, True, True]],
        {
            'tp': [4, 3, 0],
            'fp': [2, 0, 0],
            'tn': [0, 2, 2],
            'fn': [0, 1, 4],
            'precision': [2.0 / 3.0, 1.0, 0.0],
            'recall': [1.0, 0.75, 0.0],
            'thresholds': [0.0, 0.5, 1.0],
        })

  def testManyValuesWithWeights(self):
    self._testCase(
        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
        [[True, False, False, True, True, True]],
        {
            'tp': [1.5, 1.5, 0.0],
            'fp': [2.5, 0.0, 0.0],
            'tn': [0.0, 2.5, 2.5],
            'fn': [0.0, 0.0, 1.5],
            'precision': [0.375, 1.0, 0.0],
            'recall': [1.0, 1.0, 0.0],
            'thresholds': [0.0, 0.5, 1.0],
        },
        weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])


class StreamingSpecificityAtSensitivityTest(test.TestCase):

  def setUp(self):
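As a sanity check on the expected values in testManyValues, the arithmetic can be reproduced by hand with NumPy (purely illustrative, not part of the commit):

import numpy as np

p = np.array([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
l = np.array([1.0, 0.0, 0.0, 1.0, 1.0, 1.0])  # True/False cast to 1.0/0.0
n = 3  # num_thresholds, so thresholds = [0.0, 0.5, 1.0]

idx = np.floor(p * (n - 1)).astype(np.int64)  # [0, 0, 0, 1, 1, 1]
tp_b = np.zeros(n)
fp_b = np.zeros(n)
np.add.at(tp_b, idx, l)        # [1., 3., 0.]
np.add.at(fp_b, idx, 1.0 - l)  # [2., 0., 0.]

tp = np.cumsum(tp_b[::-1])[::-1]  # [4., 3., 0.]
fp = np.cumsum(fp_b[::-1])[::-1]  # [2., 0., 0.]
tn = fp[0] - fp                   # [0., 2., 2.]
fn = tp[0] - tp                   # [0., 1., 4.]

precision = tp / np.maximum(1e-7, tp + fp)  # [2/3, 1.0, 0.0]
recall = tp / np.maximum(1e-7, tp + fn)     # [1.0, 0.75, 0.0]

These values match the expected dict in testManyValues exactly.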