Adds streaming_dynamic_auc to TensorFlow contrib metrics. This metric differs from streaming_auc in that it uses every prediction value as a threshold rather than a fixed set of linearly spaced thresholds.

PiperOrigin-RevId: 175217002
A. Unique TensorFlower 2017-11-09 14:55:09 -08:00 committed by TensorFlower Gardener
parent f3f85e9aa0
commit b11a790328
3 changed files with 385 additions and 28 deletions
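
For context, a minimal usage sketch of the new metric as added in this commit; the feed mechanics and the `eval_batches` iterable are hypothetical, but the API follows the diff below:

import tensorflow as tf
from tensorflow.contrib import metrics as contrib_metrics

labels = tf.placeholder(tf.int64, shape=[None])
predictions = tf.placeholder(tf.float64, shape=[None])
auc, update_op = contrib_metrics.streaming_dynamic_auc(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  # eval_batches is an assumed iterable of (labels, predictions) arrays.
  for batch_labels, batch_preds in eval_batches:
    sess.run(update_op, {labels: batch_labels, predictions: batch_preds})
  print(sess.run(auc))  # AUC over all accumulated values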


@@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_false_negative_rate
@@streaming_false_negative_rate_at_thresholds
@@streaming_auc
@@streaming_dynamic_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@streaming_mean_absolute_error
@@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives


@@ -1178,6 +1178,154 @@ def streaming_auc(predictions,
name=name)
def _compute_dynamic_auc(labels, predictions, curve='ROC'):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
Computes the area under the ROC or PR curve using each prediction as a
threshold. This could be slow for large batches, but has the advantage of not
having its results degrade depending on the distribution of predictions.
Args:
labels: A `Tensor` of ground truth labels with the same shape as
`predictions` with values of 0 or 1 and type `int64`.
predictions: A 1-D `Tensor` of predictions whose values are `float64`.
curve: The name of the curve to be computed, 'ROC' for the Receiver
Operating Characteristic or 'PR' for the Precision-Recall curve.
Returns:
A scalar `Tensor` containing the area-under-curve value for the input.
"""
# Count the total number of positive and negative labels in the input.
size = array_ops.size(predictions)
total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
def continue_computing_dynamic_auc():
"""Continues dynamic auc computation, entered if labels are not all equal.
Returns:
A scalar `Tensor` containing the area-under-curve value.
"""
# Sort the predictions descending, and the corresponding labels as well.
ordered_predictions, indices = nn.top_k(predictions, k=size)
ordered_labels = array_ops.gather(labels, indices)
# Get the counts of the unique ordered predictions.
_, _, counts = array_ops.unique_with_counts(ordered_predictions)
# Compute the indices of the split points between different predictions.
splits = math_ops.cast(
array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
# Count the positives to the left of the split indices.
positives = math_ops.cast(
array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
dtypes.int32)
true_positives = array_ops.gather(positives, splits)
if curve == 'ROC':
# Count the negatives to the left of every split point and the total
# number of negatives for computing the FPR.
false_positives = math_ops.subtract(splits, true_positives)
total_negative = size - total_positive
x_axis_values = math_ops.truediv(false_positives, total_negative)
y_axis_values = math_ops.truediv(true_positives, total_positive)
elif curve == 'PR':
x_axis_values = math_ops.truediv(true_positives, total_positive)
# For conformance, set precision to 1 when the number of positive
# classifications is 0.
y_axis_values = array_ops.where(
math_ops.greater(splits, 0),
math_ops.truediv(true_positives, splits),
array_ops.ones_like(true_positives, dtype=dtypes.float64))
# Calculate trapezoid areas.
heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
widths = math_ops.abs(
math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
return math_ops.reduce_sum(math_ops.multiply(heights, widths))
# If all the labels are the same, AUC isn't well-defined (but raising an
# exception seems excessive), so we return 0; otherwise we finish computing.
return control_flow_ops.cond(
math_ops.logical_or(
math_ops.equal(total_positive, 0),
math_ops.equal(total_positive, size)
),
true_fn=lambda: array_ops.constant(0, dtypes.float64),
false_fn=continue_computing_dynamic_auc)
def streaming_dynamic_auc(labels,
predictions,
curve='ROC',
metrics_collections=(),
updates_collections=(),
name=None):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
USAGE NOTE: this approach requires storing all of the predictions and labels
for a single evaluation in memory, so it may not be usable when the evaluation
batch size and/or the number of evaluation steps is very large.
Computes the area under the ROC or PR curve using each prediction as a
threshold. This has the advantage of being resilient to the distribution of
predictions: labels and predictions are accumulated across batches, and the
final calculation is performed over all of the concatenated values.
Args:
labels: A `Tensor` of ground truth labels with the same shape as
`predictions` and with values of 0 or 1 that are castable to `int64`.
predictions: A `Tensor` of predictions whose values are castable to
`float64`. Will be flattened into a 1-D `Tensor`.
curve: The name of the curve for which to compute AUC, 'ROC' for the
Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
metrics_collections: An optional iterable of collections that `auc` should
be added to.
updates_collections: An optional iterable of collections that `update_op`
should be added to.
name: An optional name for the variable_scope that contains the metric
variables.
Returns:
auc: A scalar `Tensor` containing the current area-under-curve value.
update_op: An operation that concatenates the input labels and predictions
to the accumulated values.
Raises:
ValueError: If `labels` and `predictions` have mismatched shapes or if
`curve` isn't a recognized curve type.
"""
if curve not in ['PR', 'ROC']:
raise ValueError('curve must be either ROC or PR, %s unknown' % curve)
with variable_scope.variable_scope(name, default_name='dynamic_auc'):
labels.get_shape().assert_is_compatible_with(predictions.get_shape())
predictions = array_ops.reshape(
math_ops.cast(predictions, dtypes.float64), [-1])
labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
with ops.control_dependencies([
check_ops.assert_greater_equal(
labels,
array_ops.zeros_like(labels, dtypes.int64),
message='labels must be 0 or 1, at least one is <0'),
check_ops.assert_less_equal(
labels,
array_ops.ones_like(labels, dtypes.int64),
message='labels must be 0 or 1, at least one is >1')
]):
preds_accum, update_preds = streaming_concat(predictions,
name='concat_preds')
labels_accum, update_labels = streaming_concat(labels,
name='concat_labels')
update_op = control_flow_ops.group(update_labels, update_preds)
auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
if metrics_collections:
ops.add_to_collections(metrics_collections, auc)
return auc, update_op
def streaming_precision_recall_at_equal_thresholds(predictions,
labels,
num_thresholds=None,
@@ -3285,6 +3433,7 @@ __all__ = [
'streaming_accuracy',
'streaming_auc',
'streaming_curve_points',
'streaming_dynamic_auc',
'streaming_false_negative_rate',
'streaming_false_negative_rate_at_thresholds',
'streaming_false_negatives',

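For readers who want the algorithm in _compute_dynamic_auc outside of graph mode, here is a minimal NumPy sketch of the same data-derived-threshold ROC computation (illustrative only, not part of this commit; it assumes at least one positive and one negative label):

import numpy as np

def dynamic_auc_roc(labels, predictions):
  labs = labels[np.argsort(-predictions)]  # labels in descending-prediction order
  # np.unique sorts -predictions ascending, i.e. predictions descending,
  # so the per-value counts line up with the sorted order above.
  _, counts = np.unique(-predictions, return_counts=True)
  splits = np.concatenate([[0], np.cumsum(counts)])  # split points between runs
  tp = np.concatenate([[0], np.cumsum(labs)])[splits]  # positives left of split
  fp = splits - tp  # negatives left of split
  total_pos = labs.sum()
  total_neg = labs.size - total_pos
  # Trapezoidal Riemann sum over the (FPR, TPR) points.
  return np.trapz(tp / total_pos, fp / total_neg)
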

@@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase):
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
def _np_auc(predictions, labels, weights=None):
"""Computes the AUC explicitly using Numpy.
Args:
predictions: an ndarray with shape [N].
labels: an ndarray with shape [N].
weights: an ndarray with shape [N].
Returns:
the area under the ROC curve.
"""
if weights is None:
weights = np.ones(np.size(predictions))
is_positive = labels > 0
num_positives = np.sum(weights[is_positive])
num_negatives = np.sum(weights[~is_positive])
# Sort descending:
inds = np.argsort(-predictions)
sorted_labels = labels[inds]
sorted_weights = weights[inds]
is_positive = sorted_labels > 0
tp = np.cumsum(sorted_weights * is_positive) / num_positives
return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
class StreamingAUCTest(test.TestCase):
def setUp(self):
@@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def np_auc(self, predictions, labels, weights):
"""Computes the AUC explicitly using Numpy.
Args:
predictions: an ndarray with shape [N].
labels: an ndarray with shape [N].
weights: an ndarray with shape [N].
Returns:
the area under the ROC curve.
"""
if weights is None:
weights = np.ones(np.size(predictions))
is_positive = labels > 0
num_positives = np.sum(weights[is_positive])
num_negatives = np.sum(weights[~is_positive])
# Sort descending:
inds = np.argsort(-predictions)
sorted_labels = labels[inds]
sorted_weights = weights[inds]
is_positive = sorted_labels > 0
tp = np.cumsum(sorted_weights * is_positive) / num_positives
return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
def testWithMultipleUpdates(self):
num_samples = 1000
batch_size = 10
@@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase):
for weights in (None, np.ones(num_samples), np.random.exponential(
scale=1.0, size=num_samples)):
expected_auc = self.np_auc(predictions, labels, weights)
expected_auc = _np_auc(predictions, labels, weights)
with self.test_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
@@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
class StreamingDynamicAUCTest(test.TestCase):
def setUp(self):
super(StreamingDynamicAUCTest, self).setUp()
np.random.seed(1)
ops.reset_default_graph()
def testUnknownCurve(self):
with self.assertRaisesRegexp(
ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
curve='TEST_CURVE')
def testVars(self):
metrics.streaming_dynamic_auc(
labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
_assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
'dynamic_auc/concat_labels/size:0',
'dynamic_auc/concat_preds/array:0',
'dynamic_auc/concat_preds/size:0'])
def testMetricsCollection(self):
my_collection_name = '__metrics__'
auc, _ = metrics.streaming_dynamic_auc(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
metrics_collections=[my_collection_name])
self.assertEqual(ops.get_collection(my_collection_name), [auc])
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_dynamic_auc(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
updates_collections=[my_collection_name])
self.assertEqual(ops.get_collection(my_collection_name), [update_op])
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
for _ in xrange(10):
sess.run(update_op)
# Then verify idempotency.
initial_auc = auc.eval()
for _ in xrange(10):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testAllLabelsOnes(self):
with self.test_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([1, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, auc.eval())
def testAllLabelsZeros(self):
with self.test_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([0, 0, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, auc.eval())
def testNonZeroOnePredictions(self):
with self.test_session() as sess:
predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(auc.eval(), 1.0)
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
with self.test_session() as sess:
predictions = constant_op.constant(inputs)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
with self.test_session() as sess:
predictions = constant_op.constant([1, 0, 1, 0])
labels = constant_op.constant([0, 1, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0.5, auc.eval())
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
with self.test_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, auc.eval())
def testExceptionOnIncompatibleShapes(self):
with self.test_session() as sess:
predictions = array_ops.ones([5])
labels = array_ops.zeros([6])
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
def testExceptionOnGreaterThanOneLabel(self):
with self.test_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
'.*labels must be 0 or 1, at least one is >1.*'):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
with self.test_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
'.*labels must be 0 or 1, at least one is <0.*'):
sess.run(update_op)
def testWithMultipleUpdates(self):
batch_size = 10
num_batches = 100
labels = np.array([])
predictions = np.array([])
tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.int32)
tf_predictions = variables.Variable(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
noise = np.random.normal(0.0, scale=0.2, size=batch_size)
new_predictions = 0.4 + 0.2 * new_labels + noise
labels = np.concatenate([labels, new_labels])
predictions = np.concatenate([predictions, new_predictions])
sess.run(tf_labels.assign(new_labels))
sess.run(tf_predictions.assign(new_predictions))
sess.run(update_op)
expected_auc = _np_auc(predictions, labels)
self.assertAlmostEqual(expected_auc, auc.eval())
def testAUCPRReverseIncreasingPredictions(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(
labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
def testAUCPRJumbledPredictions(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
auc, update_op = metrics.streaming_dynamic_auc(
labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
def testAUCPRPredictionsLessThanHalf(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
auc, update_op = metrics.streaming_dynamic_auc(
labels, predictions, curve='PR')
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
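
As a sanity check on the expected values above, the PR AUC in testAUCPRReverseIncreasingPredictions can be reproduced by hand with NumPy (illustrative only, not part of this commit):

import numpy as np

preds = np.array([0.1, 0.4, 0.35, 0.8])
labels = np.array([0, 0, 1, 1])
labs = labels[np.argsort(-preds)]  # labels sorted by descending prediction
splits = np.arange(len(preds) + 1)  # all predictions here are distinct
tp = np.concatenate([[0], np.cumsum(labs)])  # true positives per split
recall = tp / labs.sum()  # x axis
precision = np.where(splits > 0, tp / np.maximum(splits, 1), 1.0)  # y axis
heights = (precision[1:] + precision[:-1]) / 2.0
widths = np.abs(np.diff(recall))
print(np.sum(heights * widths))  # ~0.79167, matching the assertion above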
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def setUp(self):