Adds streaming_dynamic_auc to TensorFlow contrib metrics. This metric differs from streaming_auc in that it uses every prediction as a threshold rather than a fixed set of linearly spaced thresholds.
PiperOrigin-RevId: 175217002
commit b11a790328 (parent f3f85e9aa0)
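
To illustrate the difference described above, here is a small, hypothetical NumPy sketch (not part of this change; the function and variable names are invented, and the fixed grid is only a simplified stand-in for streaming_auc's binning). When predictions cluster in a narrow range, a coarse fixed grid can badly under-resolve the curve, which is the case streaming_dynamic_auc is meant to handle:

import numpy as np

def roc_auc(labels, predictions, thresholds):
  """Trapezoidal ROC AUC over a given set of decision thresholds."""
  pos = float(labels.sum())
  neg = float(len(labels) - labels.sum())
  tpr, fpr = [], []
  for t in sorted(thresholds, reverse=True):
    predicted_pos = predictions >= t
    tpr.append((predicted_pos & (labels == 1)).sum() / pos)
    fpr.append((predicted_pos & (labels == 0)).sum() / neg)
  return np.trapz(tpr, fpr)

labels = np.array([0, 0, 1, 1])
# Predictions clustered in a narrow band around 0.5.
predictions = np.array([0.49, 0.50, 0.51, 0.52])

fixed = np.linspace(0.0, 1.0, num=5)          # a few linearly spaced thresholds
dynamic = np.unique(predictions)              # every prediction as a threshold
print(roc_auc(labels, predictions, fixed))    # 0.75: the grid misses structure
print(roc_auc(labels, predictions, dynamic))  # 1.0: the exact AUC for this data
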
@@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide.
 @@streaming_false_negative_rate
 @@streaming_false_negative_rate_at_thresholds
 @@streaming_auc
+@@streaming_dynamic_auc
 @@streaming_curve_points
 @@streaming_recall_at_k
 @@streaming_mean_absolute_error
@@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
@@ -1178,6 +1178,154 @@ def streaming_auc(predictions,
       name=name)
 
 
+def _compute_dynamic_auc(labels, predictions, curve='ROC'):
+  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.
+
+  Computes the area under the ROC or PR curve using each prediction as a
+  threshold. This could be slow for large batches, but has the advantage of not
+  having its results degrade depending on the distribution of predictions.
+
+  Args:
+    labels: A `Tensor` of ground truth labels with the same shape as
+      `predictions` with values of 0 or 1 and type `int64`.
+    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
+    curve: The name of the curve to be computed, 'ROC' for the Receiver
+      Operating Characteristic or 'PR' for the Precision-Recall curve.
+
+  Returns:
+    A scalar `Tensor` containing the area-under-curve value for the input.
+  """
+  # Count the total number of positive and negative labels in the input.
+  size = array_ops.size(predictions)
+  total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
+
+  def continue_computing_dynamic_auc():
+    """Continues dynamic auc computation, entered if labels are not all equal.
+
+    Returns:
+      A scalar `Tensor` containing the area-under-curve value.
+    """
+    # Sort the predictions descending, and the corresponding labels as well.
+    ordered_predictions, indices = nn.top_k(predictions, k=size)
+    ordered_labels = array_ops.gather(labels, indices)
+
+    # Get the counts of the unique ordered predictions.
+    _, _, counts = array_ops.unique_with_counts(ordered_predictions)
+
+    # Compute the indices of the split points between different predictions.
+    splits = math_ops.cast(
+        array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
+
+    # Count the positives to the left of the split indices.
+    positives = math_ops.cast(
+        array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
+        dtypes.int32)
+    true_positives = array_ops.gather(positives, splits)
+    if curve == 'ROC':
+      # Count the negatives to the left of every split point and the total
+      # number of negatives for computing the FPR.
+      false_positives = math_ops.subtract(splits, true_positives)
+      total_negative = size - total_positive
+      x_axis_values = math_ops.truediv(false_positives, total_negative)
+      y_axis_values = math_ops.truediv(true_positives, total_positive)
+    elif curve == 'PR':
+      x_axis_values = math_ops.truediv(true_positives, total_positive)
+      # For conformance, set precision to 1 when the number of positive
+      # classifications is 0.
+      y_axis_values = array_ops.where(
+          math_ops.greater(splits, 0),
+          math_ops.truediv(true_positives, splits),
+          array_ops.ones_like(true_positives, dtype=dtypes.float64))
+
+    # Calculate trapezoid areas.
+    heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
+    widths = math_ops.abs(
+        math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
+    return math_ops.reduce_sum(math_ops.multiply(heights, widths))
+
+  # If all the labels are the same, AUC isn't well-defined (but raising an
+  # exception seems excessive) so we return 0, otherwise we finish computing.
+  return control_flow_ops.cond(
+      math_ops.logical_or(
+          math_ops.equal(total_positive, 0),
+          math_ops.equal(total_positive, size)
+      ),
+      true_fn=lambda: array_ops.constant(0, dtypes.float64),
+      false_fn=continue_computing_dynamic_auc)
+
+
+def streaming_dynamic_auc(labels,
+                          predictions,
+                          curve='ROC',
+                          metrics_collections=(),
+                          updates_collections=(),
+                          name=None):
+  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.
+
+  USAGE NOTE: this approach requires storing all of the predictions and labels
+  for a single evaluation in memory, so it may not be usable when the evaluation
+  batch size and/or the number of evaluation steps is very large.
+
+  Computes the area under the ROC or PR curve using each prediction as a
+  threshold. This has the advantage of being resilient to the distribution of
+  predictions by aggregating across batches, accumulating labels and predictions
+  and performing the final calculation using all of the concatenated values.
+
+  Args:
+    labels: A `Tensor` of ground truth labels with the same shape as
+      `predictions` and with values of 0 or 1 whose values are castable to
+      `int64`.
+    predictions: A `Tensor` of predictions whose values are castable to
+      `float64`. Will be flattened into a 1-D `Tensor`.
+    curve: The name of the curve for which to compute AUC, 'ROC' for the
+      Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
+    metrics_collections: An optional iterable of collections that `auc` should
+      be added to.
+    updates_collections: An optional iterable of collections that `update_op`
+      should be added to.
+    name: An optional name for the variable_scope that contains the metric
+      variables.
+
+  Returns:
+    auc: A scalar `Tensor` containing the current area-under-curve value.
+    update_op: An operation that concatenates the input labels and predictions
+      to the accumulated values.
+
+  Raises:
+    ValueError: If `labels` and `predictions` have mismatched shapes or if
+      `curve` isn't a recognized curve type.
+  """
+
+  if curve not in ['PR', 'ROC']:
+    raise ValueError('curve must be either ROC or PR, %s unknown' % curve)
+
+  with variable_scope.variable_scope(name, default_name='dynamic_auc'):
+    labels.get_shape().assert_is_compatible_with(predictions.get_shape())
+    predictions = array_ops.reshape(
+        math_ops.cast(predictions, dtypes.float64), [-1])
+    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
+    with ops.control_dependencies([
+        check_ops.assert_greater_equal(
+            labels,
+            array_ops.zeros_like(labels, dtypes.int64),
+            message='labels must be 0 or 1, at least one is <0'),
+        check_ops.assert_less_equal(
+            labels,
+            array_ops.ones_like(labels, dtypes.int64),
+            message='labels must be 0 or 1, at least one is >1')
+    ]):
+      preds_accum, update_preds = streaming_concat(predictions,
+                                                   name='concat_preds')
+      labels_accum, update_labels = streaming_concat(labels,
+                                                     name='concat_labels')
+      update_op = control_flow_ops.group(update_labels, update_preds)
+      auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
+      if updates_collections:
+        ops.add_to_collections(updates_collections, update_op)
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, auc)
+      return auc, update_op
+
+
 def streaming_precision_recall_at_equal_thresholds(predictions,
                                                    labels,
                                                    num_thresholds=None,
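
A minimal usage sketch for the new metric (my own example, assuming a TF 1.x graph-mode evaluation loop; the placeholder shapes, batch loop, and random data are illustrative only):

import numpy as np
import tensorflow as tf
from tensorflow.contrib import metrics as contrib_metrics

labels_ph = tf.placeholder(tf.int64, shape=[None])
predictions_ph = tf.placeholder(tf.float64, shape=[None])

# Like other streaming metrics, this returns a value tensor and an update op;
# update_op concatenates each batch onto the accumulated labels/predictions.
auc, update_op = contrib_metrics.streaming_dynamic_auc(
    labels=labels_ph, predictions=predictions_ph, curve='ROC')

with tf.Session() as sess:
  # The accumulators created by streaming_concat are local variables.
  sess.run(tf.local_variables_initializer())
  for _ in range(10):  # evaluation batches
    batch_labels = np.random.randint(0, 2, size=32)
    batch_preds = np.random.uniform(size=32)
    sess.run(update_op, feed_dict={labels_ph: batch_labels,
                                   predictions_ph: batch_preds})
  print(sess.run(auc))  # AUC over everything accumulated so far

Because every prediction and label is kept in memory until the evaluation ends, this is best suited to evaluation jobs of modest size, as the USAGE NOTE above points out.
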
@@ -3285,6 +3433,7 @@ __all__ = [
     'streaming_accuracy',
     'streaming_auc',
     'streaming_curve_points',
+    'streaming_dynamic_auc',
     'streaming_false_negative_rate',
     'streaming_false_negative_rate_at_thresholds',
     'streaming_false_negatives',
@@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase):
         [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
 
 
+def _np_auc(predictions, labels, weights=None):
+  """Computes the AUC explicitly using Numpy.
+
+  Args:
+    predictions: an ndarray with shape [N].
+    labels: an ndarray with shape [N].
+    weights: an ndarray with shape [N].
+
+  Returns:
+    the area under the ROC curve.
+  """
+  if weights is None:
+    weights = np.ones(np.size(predictions))
+  is_positive = labels > 0
+  num_positives = np.sum(weights[is_positive])
+  num_negatives = np.sum(weights[~is_positive])
+
+  # Sort descending:
+  inds = np.argsort(-predictions)
+
+  sorted_labels = labels[inds]
+  sorted_weights = weights[inds]
+  is_positive = sorted_labels > 0
+
+  tp = np.cumsum(sorted_weights * is_positive) / num_positives
+  return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
+
+
 class StreamingAUCTest(test.TestCase):
 
   def setUp(self):
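
Aside (not part of the diff): _np_auc is the ROC AUC in its rank-statistic form. For each negative example it accumulates the weighted fraction of positives ranked above it, which equals the probability that a randomly chosen positive outranks a randomly chosen negative. A quick hand check of that reading, ignoring ties (my own example):

import numpy as np

predictions = np.array([0.8, 0.6, 0.4, 0.2])
labels = np.array([1, 0, 1, 0])

# Pairwise form: fraction of (positive, negative) pairs where the positive
# scores higher. The positives (0.8, 0.4) win 3 of the 4 pairs, so AUC = 0.75,
# the same value _np_auc(predictions, labels) returns for this input.
pos, neg = predictions[labels == 1], predictions[labels == 0]
print(np.mean(pos[:, None] > neg[None, :]))  # 0.75
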
@@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase):
 
       self.assertAlmostEqual(1, auc.eval(), 6)
 
-  def np_auc(self, predictions, labels, weights):
-    """Computes the AUC explicitly using Numpy.
-
-    Args:
-      predictions: an ndarray with shape [N].
-      labels: an ndarray with shape [N].
-      weights: an ndarray with shape [N].
-
-    Returns:
-      the area under the ROC curve.
-    """
-    if weights is None:
-      weights = np.ones(np.size(predictions))
-    is_positive = labels > 0
-    num_positives = np.sum(weights[is_positive])
-    num_negatives = np.sum(weights[~is_positive])
-
-    # Sort descending:
-    inds = np.argsort(-predictions)
-
-    sorted_labels = labels[inds]
-    sorted_weights = weights[inds]
-    is_positive = sorted_labels > 0
-
-    tp = np.cumsum(sorted_weights * is_positive) / num_positives
-    return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
-
   def testWithMultipleUpdates(self):
     num_samples = 1000
     batch_size = 10
@@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase):
 
     for weights in (None, np.ones(num_samples), np.random.exponential(
         scale=1.0, size=num_samples)):
-      expected_auc = self.np_auc(predictions, labels, weights)
+      expected_auc = _np_auc(predictions, labels, weights)
 
       with self.test_session() as sess:
         enqueue_ops = [[] for i in range(num_batches)]
@@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase):
         self.assertAlmostEqual(expected_auc, auc.eval(), 2)
 
 
+class StreamingDynamicAUCTest(test.TestCase):
+
+  def setUp(self):
+    super(StreamingDynamicAUCTest, self).setUp()
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testUnknownCurve(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
+      metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
+                                    predictions=array_ops.ones((10, 1)),
+                                    curve='TEST_CURVE')
+
+  def testVars(self):
+    metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
+    _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
+                                    'dynamic_auc/concat_labels/size:0',
+                                    'dynamic_auc/concat_preds/array:0',
+                                    'dynamic_auc/concat_preds/size:0'])
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    auc, _ = metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        metrics_collections=[my_collection_name])
+    self.assertEqual(ops.get_collection(my_collection_name), [auc])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, update_op = metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        updates_collections=[my_collection_name])
+    self.assertEqual(ops.get_collection(my_collection_name), [update_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
+    auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      # Run several updates.
+      for _ in xrange(10):
+        sess.run(update_op)
+      # Then verify idempotency.
+      initial_auc = auc.eval()
+      for _ in xrange(10):
+        self.assertAlmostEqual(initial_auc, auc.eval(), 5)
+
+  def testAllLabelsOnes(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1., 1., 1.])
+      labels = constant_op.constant([1, 1, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(0, auc.eval())
+
+  def testAllLabelsZeros(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1., 1., 1.])
+      labels = constant_op.constant([0, 0, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(0, auc.eval())
+
+  def testNonZeroOnePredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
+                                         dtype=dtypes_lib.float32)
+      labels = constant_op.constant([1, 0, 1, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(auc.eval(), 1.0)
+
+  def testAllCorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs)
+      labels = constant_op.constant(inputs)
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(1, auc.eval())
+
+  def testSomeCorrect(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0, 1, 0])
+      labels = constant_op.constant([0, 1, 1, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.5, auc.eval())
+
+  def testAllIncorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0, auc.eval())
+
+  def testExceptionOnIncompatibleShapes(self):
+    with self.test_session() as sess:
+      predictions = array_ops.ones([5])
+      labels = array_ops.zeros([6])
+      with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
+        _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+        sess.run(variables.local_variables_initializer())
+        sess.run(update_op)
+
+  def testExceptionOnGreaterThanOneLabel(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+      labels = constant_op.constant([2, 1, 0])
+      _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          '.*labels must be 0 or 1, at least one is >1.*'):
+        sess.run(update_op)
+
+  def testExceptionOnNegativeLabel(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+      labels = constant_op.constant([1, 0, -1])
+      _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          '.*labels must be 0 or 1, at least one is <0.*'):
+        sess.run(update_op)
+
+  def testWithMultipleUpdates(self):
+    batch_size = 10
+    num_batches = 100
+    labels = np.array([])
+    predictions = np.array([])
+    tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
+                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
+                                   dtype=dtypes_lib.int32)
+    tf_predictions = variables.Variable(
+        array_ops.ones(batch_size),
+        collections=[ops.GraphKeys.LOCAL_VARIABLES],
+        dtype=dtypes_lib.float32)
+    auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      for _ in xrange(num_batches):
+        new_labels = np.random.randint(0, 2, size=batch_size)
+        noise = np.random.normal(0.0, scale=0.2, size=batch_size)
+        new_predictions = 0.4 + 0.2 * new_labels + noise
+        labels = np.concatenate([labels, new_labels])
+        predictions = np.concatenate([predictions, new_predictions])
+        sess.run(tf_labels.assign(new_labels))
+        sess.run(tf_predictions.assign(new_predictions))
+        sess.run(update_op)
+      expected_auc = _np_auc(predictions, labels)
+      self.assertAlmostEqual(expected_auc, auc.eval())
+
+  def testAUCPRReverseIncreasingPredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
+
+  def testAUCPRJumbledPredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
+
+  def testAUCPRPredictionsLessThanHalf(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+          shape=(1, 7),
+          dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+
+
 class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
 
   def setUp(self):
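
For reference, the 0.79166 expectation in testAUCPRReverseIncreasingPredictions can be reproduced by hand with the same trapezoidal sum that _compute_dynamic_auc performs (my own arithmetic, not part of the change). Sorted descending, the predictions [0.8, 0.4, 0.35, 0.1] carry labels [1, 0, 1, 0]:

import numpy as np

splits = np.array([0, 1, 2, 3, 4])        # cumulative count at each threshold
true_pos = np.array([0, 1, 1, 2, 2])      # positives to the left of each split
recall = true_pos / 2.0                   # x axis; 2 positives in total
precision = np.where(splits > 0, true_pos / np.maximum(splits, 1), 1.0)  # y axis
heights = (precision[1:] + precision[:-1]) / 2.0
widths = np.abs(np.diff(recall))
print(np.sum(heights * widths))           # 0.791666..., matching the test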