From ebcae4a5e3bf5c840d73a0d90f1b5bf01a68f82c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 20 Oct 2017 15:55:17 -0700
Subject: [PATCH] Add streaming_precision_recall_at_equal_thresholds

This helper method computes streaming tp, fp, tn, fn, precision, and recall for the user in a way that exhibits O(T + N) time and space complexity (instead of O(T * N)), where T is the number of thresholds and N is the size of the predictions tensor.

Thanks to Frank Chu for the efficient algorithm!
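
For reference, a rough usage sketch (the example inputs and session setup below
are illustrative only and are not part of this change):

    # Hypothetical example: accumulate one batch, then read the curve values.
    import tensorflow as tf
    from tensorflow.contrib.metrics.python.ops import metric_ops

    predictions = tf.constant([0.1, 0.6, 0.9])  # floats in [0, 1]
    labels = tf.constant([False, True, True])   # bools, same shape

    result, update_op = (
        metric_ops.streaming_precision_recall_at_equal_thresholds(
            predictions=predictions, labels=labels, num_thresholds=5))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())  # the buckets are local vars
      sess.run(update_op)                         # accumulate the batch
      precision, recall, thresholds = sess.run(
          [result.precision, result.recall, result.thresholds])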

PiperOrigin-RevId: 172946073
---
 .../contrib/metrics/python/ops/metric_ops.py  | 180 ++++++++++++++++++
 .../metrics/python/ops/metric_ops_test.py     | 164 ++++++++++++++++
 2 files changed, 344 insertions(+)

diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 09485c4fa2a..5a4c0c43585 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -22,6 +22,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections as collections_lib
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -1076,6 +1078,9 @@ def streaming_curve_points(labels=None,
       `weights` is not `None` and its shape doesn't match `predictions`, or if
       either `metrics_collections` or `updates_collections` are not a list or
       tuple.
+
+  TODO(chizeng): Consider rewriting this method to make use of logic within the
+  streaming_precision_recall_at_equal_thresholds method (to improve run time).
   """
   with variable_scope.variable_scope(name, 'curve_points',
                                      (labels, predictions, weights)):
@@ -1193,6 +1198,181 @@ def streaming_auc(predictions,
       name=name)
 
 
+def streaming_precision_recall_at_equal_thresholds(predictions,
+                                                   labels,
+                                                   num_thresholds=None,
+                                                   weights=None,
+                                                   name=None,
+                                                   use_locking=None):
+  """A helper method for creating metrics related to precision-recall curves.
+
+  These values are true positives, false positives, true negatives, false
+  negatives, precision, and recall. This function returns a named tuple of
+  tensors holding these values, along with a single update op.
+
+  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
+  space and run time), this op exhibits O(T + N) space and run time, where T is
+  the number of thresholds and N is the size of the predictions tensor. Hence,
+  it may be advantageous to use this function when `predictions` is big.
+
+  For instance, prefer this method for per-pixel classification tasks, for which
+  the predictions tensor may be very large.
+
+  Each number in `predictions`, a float in `[0, 1]`, is compared with its
+  corresponding boolean label in `labels` and counts as a single tp/fp/tn/fn
+  value at each threshold. This count is then multiplied by `weights`, which
+  can be used to reweight certain values or, more commonly, to mask values.
+
+  Args:
+    predictions: A floating point `Tensor` of arbitrary shape and whose values
+      are in the range `[0, 1]`.
+    labels: A bool `Tensor` whose shape matches `predictions`.
+    num_thresholds: Optional; Number of thresholds, evenly distributed in
+      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
+      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
+      instead of an odd one may yield unfriendly edges for bins.
+    weights: Optional; If provided, a `Tensor` that has the same dtype as,
+      and broadcastable to, `predictions`. This tensor is multiplied by counts.
+    name: Optional; variable_scope name. If not provided, the string
+      'precision_recall_at_equal_thresholds' is used.
+    use_locking: Optional; If True, the op will be protected by a lock.
+      Otherwise, the behavior is undefined, but may exhibit less contention.
+      Defaults to True.
+
+  Returns:
+    result: A named tuple (see `PrecisionRecallData` within the implementation
+      of this function) with properties that are tensors of shape
+      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
+      precision, recall, thresholds.
+    update_op: An op that accumulates values.
+
+  Raises:
+    ValueError: If `predictions` and `labels` have mismatched shapes, or if
+      `weights` is not `None` and its shape doesn't match `predictions`.
+  """
+  # Disable the invalid-name checker so that we can capitalize the name.
+  # pylint: disable=invalid-name
+  PrecisionRecallData = collections_lib.namedtuple(
+      'PrecisionRecallData',
+      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
+  # pylint: enable=invalid-name
+
+  if num_thresholds is None:
+    num_thresholds = 201
+
+  if weights is None:
+    weights = 1.0
+
+  if use_locking is None:
+    use_locking = True
+
+  check_ops.assert_type(labels, dtypes.bool)
+
+  dtype = predictions.dtype
+  with variable_scope.variable_scope(name,
+                                     'precision_recall_at_equal_thresholds',
+                                     (labels, predictions, weights)):
+    # Make sure that predictions are within [0.0, 1.0].
+    with ops.control_dependencies([
+        check_ops.assert_greater_equal(
+            predictions,
+            math_ops.cast(0.0, dtype=predictions.dtype),
+            message='predictions must be in [0, 1]'),
+        check_ops.assert_less_equal(
+            predictions,
+            math_ops.cast(1.0, dtype=predictions.dtype),
+            message='predictions must be in [0, 1]')
+    ]):
+      predictions, labels, weights = _remove_squeezable_dimensions(
+          predictions=predictions, labels=labels, weights=weights)
+
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+    # We cast to float to ensure we have 0.0 or 1.0.
+    f_labels = math_ops.cast(labels, dtype)
+
+    # Get weighted true/false labels.
+    true_labels = f_labels * weights
+    false_labels = (1.0 - f_labels) * weights
+
+    # Flatten predictions and labels.
+    predictions = array_ops.reshape(predictions, [-1])
+    true_labels = array_ops.reshape(true_labels, [-1])
+    false_labels = array_ops.reshape(false_labels, [-1])
+
+    # To compute TP/FP/TN/FN, we are measuring a binary classifier
+    #   C(t) = (predictions >= t)
+    # at each threshold 't'. So we have
+    #   TP(t) = sum( C(t) * true_labels )
+    #   FP(t) = sum( C(t) * false_labels )
+    #
+    # But, computing C(t) requires computation for each t. To make it fast,
+    # observe that C(t) is a cumulative integral, and so if we have
+    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
+    # where n = num_thresholds, and if we can compute the bucket function
+    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
+    # then we get
+    #   C(t_i) = sum( B(j), j >= i )
+    # which is the reversed cumulative sum in tf.cumsum().
+    #
+    # We can compute B(i) efficiently by taking advantage of the fact that
+    # our thresholds are evenly distributed, in that
+    #   width = 1.0 / (num_thresholds - 1)
+    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
+    # Given a prediction value p, we can map it to its bucket by
+    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
+    # so we can use tf.scatter_add() to update the buckets in one pass.
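+    # For example, with num_thresholds = 3 the thresholds are [0.0, 0.5, 1.0]
+    # and width = 0.5, so a prediction of 0.3 falls in bucket
+    # floor(0.3 * 2) = 0 and a prediction of 1.0 falls in the last bucket,
+    # floor(1.0 * 2) = 2.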
+    #
+    # This implementation exhibits a run time and space complexity of O(T + N),
+    # where T is the number of thresholds and N is the size of predictions.
+    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
+    # exhibit a complexity of O(T * N).
+
+    # Compute the bucket indices for each prediction value.
+    bucket_indices = math_ops.cast(
+        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
+
+    with ops.name_scope('variables'):
+      tp_buckets_v = _create_local(
+          'tp_buckets', shape=[num_thresholds], dtype=dtype)
+      fp_buckets_v = _create_local(
+          'fp_buckets', shape=[num_thresholds], dtype=dtype)
+
+    with ops.name_scope('update_op'):
+      update_tp = state_ops.scatter_add(
+          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
+      update_fp = state_ops.scatter_add(
+          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)
+
+    # Set up the cumulative sums to compute the actual metrics.
+    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
+    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
+    # fn = sum(true_labels) - tp
+    #    = sum(tp_buckets) - tp
+    #    = tp[0] - tp
+    # Similarly,
+    # tn = fp[0] - fp
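+    # For example, if tp_buckets has accumulated to [1, 3, 0], then
+    # tp = [4, 3, 0] and fn = tp[0] - tp = [0, 1, 4].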
+    tn = fp[0] - fp
+    fn = tp[0] - tp
+
+    # We take the maximum with epsilon to prevent division by 0.
+    epsilon = 1e-7
+    precision = tp / math_ops.maximum(epsilon, tp + fp)
+    recall = tp / math_ops.maximum(epsilon, tp + fn)
+
+    result = PrecisionRecallData(
+        tp=tp,
+        fp=fp,
+        tn=tn,
+        fn=fn,
+        precision=precision,
+        recall=recall,
+        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
+    update_op = control_flow_ops.group(update_tp, update_fp)
+    return result, update_op
+
+
 def streaming_specificity_at_sensitivity(predictions,
                                          labels,
                                          sensitivity,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f288fceef6c..f24bec7f115 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1970,6 +1970,170 @@ class StreamingAUCTest(test.TestCase):
         self.assertAlmostEqual(expected_auc, auc.eval(), 2)
 
 
+class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def _testResultsEqual(self, expected_dict, gotten_result):
+    """Tests that 2 results (dicts) represent the same data.
+
+    Args:
+      expected_dict: A dictionary with keys that are the names of properties
+        of PrecisionRecallData and whose values are lists of floats.
+      gotten_result: A PrecisionRecallData object.
+    """
+    gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
+    self.assertItemsEqual(
+        list(expected_dict.keys()), list(gotten_dict.keys()))
+
+    for key, expected_values in expected_dict.items():
+      self.assertAllClose(expected_values, gotten_dict[key])
+
+  def _testCase(self, predictions, labels, expected_result, weights=None):
+    """Performs a test given a certain scenario of labels, predictions, weights.
+
+    Args:
+      predictions: The predictions tensor. Of type float32.
+      labels: The labels tensor. Of type bool.
+      expected_result: The expected result (dict) that maps to tensors.
+      weights: Optional weights tensor.
+    """
+    with self.test_session() as sess:
+      predictions_tensor = constant_op.constant(
+          predictions, dtype=dtypes_lib.float32)
+      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
+      weights_tensor = None
+      if weights:
+        weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32)
+      gotten_result, update_op = (
+          metric_ops.streaming_precision_recall_at_equal_thresholds(
+              predictions=predictions_tensor,
+              labels=labels_tensor,
+              num_thresholds=3,
+              weights=weights_tensor))
+
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+
+      self._testResultsEqual(expected_result, gotten_result)
+
+  def testVars(self):
+    metric_ops.streaming_precision_recall_at_equal_thresholds(
+        predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
+        labels=constant_op.constant([True], dtype=dtypes_lib.bool))
+    _assert_local_variables(
+        self,
+        (
+            'precision_recall_at_equal_thresholds/variables/tp_buckets:0',
+            'precision_recall_at_equal_thresholds/variables/fp_buckets:0'
+        ))
+
+  def testVarsWithName(self):
+    metric_ops.streaming_precision_recall_at_equal_thresholds(
+        predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32),
+        labels=constant_op.constant([True], dtype=dtypes_lib.bool),
+        name='foo')
+    _assert_local_variables(
+        self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0'))
+
+  def testValuesAreIdempotent(self):
+    predictions = constant_op.constant(
+        np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
+    labels = constant_op.constant(
+        np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool)
+
+    result, update_op = (
+        metric_ops.streaming_precision_recall_at_equal_thresholds(
+            predictions=predictions, labels=labels))
+
+    with self.test_session() as sess:
+      # Run several updates.
+      sess.run(variables.local_variables_initializer())
+      for _ in range(3):
+        sess.run(update_op)
+
+      # Then verify idempotency.
+      initial_result = {k: value.eval().tolist() for k, value in
+                        result._asdict().items()}
+      for _ in range(3):
+        self._testResultsEqual(initial_result, result)
+
+  def testAllTruePositives(self):
+    self._testCase([[1]], [[True]], {
+        'tp': [1, 1, 1],
+        'fp': [0, 0, 0],
+        'tn': [0, 0, 0],
+        'fn': [0, 0, 0],
+        'precision': [1.0, 1.0, 1.0],
+        'recall': [1.0, 1.0, 1.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllTrueNegatives(self):
+    self._testCase([[0]], [[False]], {
+        'tp': [0, 0, 0],
+        'fp': [1, 0, 0],
+        'tn': [0, 1, 1],
+        'fn': [0, 0, 0],
+        'precision': [0.0, 0.0, 0.0],
+        'recall': [0.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllFalsePositives(self):
+    self._testCase([[1]], [[False]], {
+        'tp': [0, 0, 0],
+        'fp': [1, 1, 1],
+        'tn': [0, 0, 0],
+        'fn': [0, 0, 0],
+        'precision': [0.0, 0.0, 0.0],
+        'recall': [0.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testAllFalseNegatives(self):
+    self._testCase([[0]], [[True]], {
+        'tp': [1, 0, 0],
+        'fp': [0, 0, 0],
+        'tn': [0, 0, 0],
+        'fn': [0, 1, 1],
+        'precision': [1.0, 0.0, 0.0],
+        'recall': [1.0, 0.0, 0.0],
+        'thresholds': [0.0, 0.5, 1.0],
+    })
+
+  def testManyValues(self):
+    self._testCase(
+        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+        [[True, False, False, True, True, True]],
+        {
+            'tp': [4, 3, 0],
+            'fp': [2, 0, 0],
+            'tn': [0, 2, 2],
+            'fn': [0, 1, 4],
+            'precision': [2.0 / 3.0, 1.0, 0.0],
+            'recall': [1.0, 0.75, 0.0],
+            'thresholds': [0.0, 0.5, 1.0],
+        })
+
+  def testManyValuesWithWeights(self):
+    self._testCase(
+        [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+        [[True, False, False, True, True, True]],
+        {
+            'tp': [1.5, 1.5, 0.0],
+            'fp': [2.5, 0.0, 0.0],
+            'tn': [0.0, 2.5, 2.5],
+            'fn': [0.0, 0.0, 1.5],
+            'precision': [0.375, 1.0, 0.0],
+            'recall': [1.0, 1.0, 0.0],
+            'thresholds': [0.0, 0.5, 1.0],
+        },
+        weights=[0.0, 0.5, 2.0, 0.0, 0.5, 1.0])
+
+
 class StreamingSpecificityAtSensitivityTest(test.TestCase):
 
   def setUp(self):