diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 09485c4fa2a..5a4c0c43585 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -22,6 +22,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections as collections_lib
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -1076,6 +1078,9 @@ def streaming_curve_points(labels=None,
       `weights` is not `None` and its shape doesn't match `predictions`, or if
       either `metrics_collections` or `updates_collections` are not a list or
       tuple.
+
+  TODO(chizeng): Consider rewriting this method to make use of logic within the
+  streaming_precision_recall_at_equal_thresholds method (to improve run time).
   """
   with variable_scope.variable_scope(name, 'curve_points',
                                      (labels, predictions, weights)):
@@ -1193,6 +1198,181 @@ def streaming_auc(predictions,
       name=name)
 
 
+def streaming_precision_recall_at_equal_thresholds(predictions,
+                                                   labels,
+                                                   num_thresholds=None,
+                                                   weights=None,
+                                                   name=None,
+                                                   use_locking=None):
+  """A helper method for creating metrics related to precision-recall curves.
+
+  These values are true positives, false negatives, true negatives, false
+  positives, precision, and recall. This function returns a named tuple of ops
+  that compute these values, along with a single accumulating update op.
+
+  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
+  space and run time), this op exhibits O(T + N) space and run time, where T is
+  the number of thresholds and N is the size of the predictions tensor. Hence,
+  it may be advantageous to use this function when `predictions` is big.
+
+  For instance, prefer this method for per-pixel classification tasks, for
+  which the predictions tensor may be very large.
+
+  Each number in `predictions`, a float in `[0, 1]`, is compared with its
+  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
+  each threshold. This is then multiplied by `weights`, which can be used to
+  reweight certain values, or more commonly, to mask values.
+
+  Args:
+    predictions: A floating point `Tensor` of arbitrary shape and whose values
+      are in the range `[0, 1]`.
+    labels: A bool `Tensor` whose shape matches `predictions`.
+    num_thresholds: Optional; Number of thresholds, evenly distributed in
+      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
+      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
+      instead of an odd one may yield unfriendly edges for bins.
+    weights: Optional; If provided, a `Tensor` that has the same dtype as, and
+      broadcastable to, `predictions`. This tensor is multiplied by counts.
+    name: Optional; variable_scope name. If not provided, the string
+      'precision_recall_at_equal_thresholds' is used.
+    use_locking: Optional; If True, the op will be protected by a lock.
+      Otherwise, the behavior is undefined, but may exhibit less contention.
+      Defaults to True.
+
+  Returns:
+    result: A named tuple (see PrecisionRecallData within the implementation of
+      this function) with properties that are tensors of shape
+      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
+      precision, recall, thresholds.
+    update_op: An op that accumulates values.
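+
+    For instance, after `update_op` has been run, evaluating
+    `result.precision` and `result.recall` yields the precision-recall curve
+    sampled at the `num_thresholds` thresholds.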
+
+  Raises:
+    ValueError: If `predictions` and `labels` have mismatched shapes, or if
+      `weights` is not `None` and its shape doesn't match `predictions`.
+  """
+  # Disable the invalid-name checker so that we can capitalize the name.
+  # pylint: disable=invalid-name
+  PrecisionRecallData = collections_lib.namedtuple(
+      'PrecisionRecallData',
+      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
+  # pylint: enable=invalid-name
+
+  if num_thresholds is None:
+    num_thresholds = 201
+
+  if weights is None:
+    weights = 1.0
+
+  if use_locking is None:
+    use_locking = True
+
+  check_ops.assert_type(labels, dtypes.bool)
+
+  dtype = predictions.dtype
+  with variable_scope.variable_scope(name,
+                                     'precision_recall_at_equal_thresholds',
+                                     (labels, predictions, weights)):
+    # Make sure that predictions are within [0.0, 1.0].
+    with ops.control_dependencies([
+        check_ops.assert_greater_equal(
+            predictions,
+            math_ops.cast(0.0, dtype=predictions.dtype),
+            message='predictions must be in [0, 1]'),
+        check_ops.assert_less_equal(
+            predictions,
+            math_ops.cast(1.0, dtype=predictions.dtype),
+            message='predictions must be in [0, 1]')
+    ]):
+      predictions, labels, weights = _remove_squeezable_dimensions(
+          predictions=predictions, labels=labels, weights=weights)
+
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+    # We cast to float to ensure we have 0.0 or 1.0.
+    f_labels = math_ops.cast(labels, dtype)
+
+    # Get weighted true/false labels.
+    true_labels = f_labels * weights
+    false_labels = (1.0 - f_labels) * weights
+
+    # Flatten predictions and labels.
+    predictions = array_ops.reshape(predictions, [-1])
+    true_labels = array_ops.reshape(true_labels, [-1])
+    false_labels = array_ops.reshape(false_labels, [-1])
+
+    # To compute TP/FP/TN/FN, we are measuring a binary classifier
+    #   C(t) = (predictions >= t)
+    # at each threshold 't'. So we have
+    #   TP(t) = sum( C(t) * true_labels )
+    #   FP(t) = sum( C(t) * false_labels )
+    #
+    # But, computing C(t) requires computation for each t. To make it fast,
+    # observe that C(t) is a cumulative integral, and so if we have
+    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
+    # where n = num_thresholds, and if we can compute the bucket function
+    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
+    # then we get
+    #   C(t_i) = sum( B(j), j >= i )
+    # which is the reversed cumulative sum in tf.cumsum().
+    #
+    # We can compute B(i) efficiently by taking advantage of the fact that
+    # our thresholds are evenly distributed, in that
+    #   width = 1.0 / (num_thresholds - 1)
+    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
+    # Given a prediction value p, we can map it to its bucket by
+    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
+    # so we can use tf.scatter_add() to update the buckets in one pass.
+    #
+    # This implementation exhibits a run time and space complexity of O(T + N),
+    # where T is the number of thresholds and N is the size of predictions.
+    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
+    # exhibit a complexity of O(T * N).
+
+    # Compute the bucket indices for each prediction value.
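+    # For example, with num_thresholds = 3 (as in the tests below), the bucket
+    # width is 0.5, the thresholds are [0.0, 0.5, 1.0], and a prediction of
+    # 0.6 maps to bucket_index(0.6) = floor(0.6 * 2) = 1. A prediction of
+    # exactly 1.0 maps to the last bucket, index num_thresholds - 1.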
+    bucket_indices = math_ops.cast(
+        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
+
+    with ops.name_scope('variables'):
+      tp_buckets_v = _create_local(
+          'tp_buckets', shape=[num_thresholds], dtype=dtype)
+      fp_buckets_v = _create_local(
+          'fp_buckets', shape=[num_thresholds], dtype=dtype)
+
+    with ops.name_scope('update_op'):
+      update_tp = state_ops.scatter_add(
+          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
+      update_fp = state_ops.scatter_add(
+          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)
+
+    # Set up the cumulative sums to compute the actual metrics.
+    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
+    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
+    # fn = sum(true_labels) - tp
+    #    = sum(tp_buckets) - tp
+    #    = tp[0] - tp
+    # Similarly,
+    # tn = fp[0] - fp
+    tn = fp[0] - fp
+    fn = tp[0] - tp
+
+    # Use epsilon as a lower bound on the denominators to prevent division
+    # by 0.
+    epsilon = 1e-7
+    precision = tp / math_ops.maximum(epsilon, tp + fp)
+    recall = tp / math_ops.maximum(epsilon, tp + fn)
+
+    result = PrecisionRecallData(
+        tp=tp,
+        fp=fp,
+        tn=tn,
+        fn=fn,
+        precision=precision,
+        recall=recall,
+        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
+    update_op = control_flow_ops.group(update_tp, update_fp)
+    return result, update_op
+
+
 def streaming_specificity_at_sensitivity(predictions,
                                          labels,
                                          sensitivity,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f288fceef6c..f24bec7f115 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1970,6 +1970,170 @@ class StreamingAUCTest(test.TestCase):
     self.assertAlmostEqual(expected_auc, auc.eval(), 2)
 
 
+class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def _testResultsEqual(self, expected_dict, gotten_result):
+    """Tests that two results represent the same data.
+
+    Args:
+      expected_dict: A dictionary with keys that are the names of properties
+        of PrecisionRecallData and whose values are lists of floats.
+      gotten_result: A PrecisionRecallData object.
+    """
+    gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
+    self.assertItemsEqual(
+        list(expected_dict.keys()), list(gotten_dict.keys()))
+
+    for key, expected_values in expected_dict.items():
+      self.assertAllClose(expected_values, gotten_dict[key])
+
+  def _testCase(self, predictions, labels, expected_result, weights=None):
+    """Performs a test for a given scenario of predictions, labels, weights.
+
+    Args:
+      predictions: The predictions tensor. Of type float32.
+      labels: The labels tensor. Of type bool.
+      expected_result: A dict mapping PrecisionRecallData property names to
+        their expected values.
+      weights: Optional weights tensor.
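+        Of type float32.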
+ """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) + gotten_result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions_tensor, + labels=labels_tensor, + num_thresholds=3, + weights=weights_tensor)) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result) + + def testVars(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool)) + _assert_local_variables( + self, + ( + 'precision_recall_at_equal_thresholds/variables/tp_buckets:0', + 'precision_recall_at_equal_thresholds/variables/fp_buckets:0' + )) + + def testVarsWithName(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool), + name='foo') + _assert_local_variables( + self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0')) + + def testValuesAreIdempotent(self): + predictions = constant_op.constant( + np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32) + labels = constant_op.constant( + np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool) + + result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions, labels=labels)) + + with self.test_session() as sess: + # Run several updates. + sess.run(variables.local_variables_initializer()) + for _ in range(3): + sess.run(update_op) + + # Then verify idempotency. 
+      initial_result = {k: value.eval().tolist()
+                        for k, value in result._asdict().items()}
+      for _ in range(3):
+        self._testResultsEqual(initial_result, result)
+
+  def testAllTruePositives(self):
+    self._testCase([[1]], [[True]], {
+        'tp': [1, 1, 1],
+        'fp': [0, 0, 0],
+        'tn': [0, 0, 0],
+        'fn': [0, 0, 0],
+        'precision': [1.0, 1.0, 1.0],
+        'recall': [1.0, 1.0, 1.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllTrueNegatives(self):
+    self._testCase([[0]], [[False]], {
+        'tp': [0, 0, 0],
+        'fp': [1, 0, 0],
+        'tn': [0, 1, 1],
+        'fn': [0, 0, 0],
+        'precision': [0.0, 0.0, 0.0],
+        'recall': [0.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllFalsePositives(self):
+    self._testCase([[1]], [[False]], {
+        'tp': [0, 0, 0],
+        'fp': [1, 1, 1],
+        'tn': [0, 0, 0],
+        'fn': [0, 0, 0],
+        'precision': [0.0, 0.0, 0.0],
+        'recall': [0.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllFalseNegatives(self):
+    self._testCase([[0]], [[True]], {
+        'tp': [1, 0, 0],
+        'fp': [0, 0, 0],
+        'tn': [0, 0, 0],
+        'fn': [0, 1, 1],
+        'precision': [1.0, 0.0, 0.0],
+        'recall': [1.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testManyValues(self):
+    self._testCase(
+        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+        [[True, False, False, True, True, True]],
+        {
+            'tp': [4, 3, 0],
+            'fp': [2, 0, 0],
+            'tn': [0, 2, 2],
+            'fn': [0, 1, 4],
+            'precision': [2.0 / 3.0, 1.0, 0.0],
+            'recall': [1.0, 0.75, 0.0],
+            'thresholds': [0.0, 0.5, 1.0],
+        })
+
+  def testManyValuesWithWeights(self):
+    self._testCase(
+        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+        [[True, False, False, True, True, True]],
+        {
+            'tp': [1.5, 1.5, 0.0],
+            'fp': [2.5, 0.0, 0.0],
+            'tn': [0.0, 2.5, 2.5],
+            'fn': [0.0, 0.0, 1.5],
+            'precision': [0.375, 1.0, 0.0],
+            'recall': [1.0, 1.0, 0.0],
+            'thresholds': [0.0, 0.5, 1.0],
+        },
+        weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])
+
+
 class StreamingSpecificityAtSensitivityTest(test.TestCase):
 
   def setUp(self):