Moves most metrics from contrib into core.

Change: 140914784
A. Unique TensorFlower 2016-12-02 17:43:26 -08:00 committed by TensorFlower Gardener
parent ccf6cf533c
commit 0e5015bb7d
13 changed files with 6286 additions and 645 deletions
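In practical terms, code that called the `tf.contrib.metrics.streaming_*` wrappers can now call the core `tf.metrics` module directly; the contrib functions in this diff become thin forwarding shims. A minimal migration sketch, editor-added for illustration (session and initializer names assumed from the TF 0.12-era API; data is a toy batch):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])

# Core replacement for tf.contrib.metrics.streaming_accuracy.
accuracy, update_op = tf.metrics.accuracy(labels=labels,
                                          predictions=predictions)

with tf.Session() as sess:
    # Streaming metrics accumulate state in local variables.
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)        # fold this batch into the running totals
    print(sess.run(accuracy))  # => 0.75
```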

View File

@@ -33,19 +33,6 @@ py_test(
],
)
py_test(
name = "confusion_matrix_ops_test",
size = "medium",
srcs = ["python/kernel_tests/confusion_matrix_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":metrics_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "histogram_ops_test",
size = "medium",

View File

@@ -133,7 +133,6 @@ labels and predictions tensors and results in a weighted average of the metric.
@@auc_using_histogram
@@accuracy
@@confusion_matrix
@@aggregate_metrics
@@aggregate_metric_map

View File

@@ -18,93 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework import tensor_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import confusion_matrix as cm
def confusion_matrix(predictions, labels, num_classes=None, dtype=dtypes.int32,
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
name=None, weights=None):
"""Computes the confusion matrix from predictions and labels.
Calculates the confusion matrix for a pair of 1-D int prediction and
label arrays.
The matrix rows represent the prediction labels and the columns
represent the real labels. The confusion matrix is always a 2-D array
of shape `[n, n]`, where `n` is the number of valid labels for a given
classification task. Both predictions and labels must be 1-D arrays of
the same shape in order for this function to work.
If `num_classes` is None, then `num_classes` will be set to one plus
the maximum value in either predictions or labels.
Class labels are expected to start at 0. E.g., if `num_classes` was
three, then the possible labels would be `[0, 1, 2]`.
If `weights` is not `None`, then each prediction contributes its
corresponding weight to the total value of the confusion matrix cell.
For example:
```python
tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
[[0 0 0 0 0]
[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
```
Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
resulting in a 5x5 confusion matrix.
Args:
predictions: A 1-D array representing the predictions for a given
classification.
labels: A 1-D array representing the real labels for the classification task.
num_classes: The possible number of labels the classification task can
have. If this value is not provided, it will be calculated
using both predictions and labels array.
dtype: Data type of the confusion matrix.
name: Scope name.
weights: An optional `Tensor` whose shape matches `predictions`.
Returns:
A k X k matrix representing the confusion matrix, where k is the number of
possible labels in the classification task.
Raises:
ValueError: If both predictions and labels are not 1-D vectors and have
mismatched shapes, or if `weights` is not `None` and its shape doesn't
match `predictions`.
"""
with ops.name_scope(name, 'confusion_matrix',
[predictions, labels, num_classes]) as name:
predictions, labels = tensor_util.remove_squeezable_dimensions(
ops.convert_to_tensor(
predictions, name='predictions'),
ops.convert_to_tensor(labels, name='labels'))
predictions = math_ops.cast(predictions, dtypes.int64)
labels = math_ops.cast(labels, dtypes.int64)
if num_classes is None:
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
math_ops.reduce_max(labels)) + 1
if weights is not None:
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
weights = math_ops.cast(weights, dtype)
shape = array_ops.pack([num_classes, num_classes])
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
indices=indices, values=values, shape=math_ops.to_int64(shape))
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
"""Deprecated. Use tf.confusion_matrix instead."""
return cm.confusion_matrix(labels=labels, predictions=predictions,
num_classes=num_classes, dtype=dtype, name=name,
weights=weights)
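Note that the argument order flips during this migration: the contrib wrapper took `(predictions, labels)` while the core op takes `(labels, predictions)`, exactly as the shim above shows. A small editor-added sketch of a migrated call (toy data; session API assumed from the 0.12 era):

```python
import tensorflow as tf

labels = tf.constant([1, 2, 4])
predictions = tf.constant([2, 2, 4])

# Old: tf.contrib.metrics.confusion_matrix(predictions, labels, ...)
# New: labels come first.
cm = tf.confusion_matrix(labels, predictions, num_classes=5)

with tf.Session() as sess:
    print(sess.run(cm))  # => a 5x5 integer matrix of pair counts
```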

View File

@@ -25,7 +25,6 @@ from __future__ import print_function
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -34,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
@@ -178,16 +178,10 @@ def streaming_true_positives(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'true_positives', (predictions, labels, weights)):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1),
math_ops.equal(predictions, 1))
return _count_condition(is_true_positive, weights, metrics_collections,
updates_collections)
return metrics.true_positives(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_true_negatives(predictions, labels, weights=None,
@@ -262,16 +256,10 @@ def streaming_false_positives(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'false_positives', (predictions, labels, weights)):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0),
math_ops.equal(predictions, 1))
return _count_condition(is_false_positive, weights, metrics_collections,
updates_collections)
return metrics.false_positives(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_false_negatives(predictions, labels, weights=None,
@@ -303,16 +291,10 @@ def streaming_false_negatives(predictions, labels, weights=None,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
with variable_scope.variable_scope(
name, 'false_negatives', (predictions, labels, weights)):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1),
math_ops.equal(predictions, 0))
return _count_condition(is_false_negative, weights, metrics_collections,
updates_collections)
return metrics.false_negatives(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
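The three counters above now delegate to their core equivalents. A hedged, editor-added sketch of using the core ops directly (toy batch; initializer name assumed from the 0.12-era API):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 0])
predictions = tf.constant([1, 1, 0, 0])

tp, tp_up = tf.metrics.true_positives(labels=labels, predictions=predictions)
fp, fp_up = tf.metrics.false_positives(labels=labels, predictions=predictions)
fn, fn_up = tf.metrics.false_negatives(labels=labels, predictions=predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run([tp_up, fp_up, fn_up])
    print(sess.run([tp, fp, fn]))  # => [1.0, 1.0, 1.0]
```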
def _broadcast_weights(weights, values):
@@ -376,33 +358,9 @@ def streaming_mean(values, weights=None, metrics_collections=None,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
with variable_scope.variable_scope(name, 'mean', (values, weights)):
values = math_ops.to_float(values)
total = _create_local('total', shape=[])
count = _create_local('count', shape=[])
if weights is not None:
weights = math_ops.to_float(weights)
values = math_ops.mul(values, weights)
num_values = math_ops.reduce_sum(_broadcast_weights(weights, values))
else:
num_values = math_ops.to_float(array_ops.size(values))
total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
count_compute_op = state_ops.assign_add(count, num_values)
mean = _safe_div(total, count, 'value')
with ops.control_dependencies([total_compute_op, count_compute_op]):
update_op = _safe_div(total, count, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return mean, update_op
return metrics.mean(
values=values, weights=weights, metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
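The deleted body shows what core `metrics.mean` still does internally: maintain `total` and `count` local variables and return `total / count`. A short editor-added sketch of the weighted case (toy values; session API assumed):

```python
import tensorflow as tf

values = tf.constant([1.0, 2.0, 3.0, 4.0])
weights = tf.constant([1.0, 1.0, 0.0, 0.0])  # mask out the last two values

mean, update_op = tf.metrics.mean(values, weights=weights)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(mean))  # => 1.5, i.e. (1*1 + 2*1) / (1 + 1)
```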
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
@@ -445,36 +403,9 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
with variable_scope.variable_scope(name, 'mean', (values, weights)):
total = _create_local('total_tensor', shape=values.get_shape())
count = _create_local('count_tensor', shape=values.get_shape())
num_values = array_ops.ones_like(values)
if weights is not None:
weights = math_ops.to_float(weights)
values = math_ops.mul(values, weights)
num_values = math_ops.mul(num_values, weights)
total_compute_op = state_ops.assign_add(total, values)
count_compute_op = state_ops.assign_add(count, num_values)
def compute_mean(total, count, name):
non_zero_count = math_ops.maximum(count,
array_ops.ones_like(count),
name=name)
return math_ops.truediv(total, non_zero_count, name=name)
mean = compute_mean(total, count, 'value')
with ops.control_dependencies([total_compute_op, count_compute_op]):
update_op = compute_mean(total, count, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return mean, update_op
return metrics.mean_tensor(
values=values, weights=weights, metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_accuracy(predictions, labels, weights=None,
@@ -520,14 +451,10 @@ def streaming_accuracy(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights=weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
if labels.dtype != predictions.dtype:
predictions = math_ops.cast(predictions, labels.dtype)
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
return streaming_mean(is_correct, weights, metrics_collections,
updates_collections, name or 'accuracy')
return metrics.accuracy(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_precision(predictions, labels, weights=None,
@@ -572,39 +499,10 @@ def streaming_precision(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'precision', (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
false_positives, false_positives_update_op = streaming_false_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
def compute_precision(name):
return array_ops.where(
math_ops.greater(true_positives + false_positives, 0),
math_ops.div(true_positives, true_positives + false_positives),
0,
name)
precision = compute_precision('value')
with ops.control_dependencies([true_positives_update_op,
false_positives_update_op]):
update_op = compute_precision('update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, precision)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return precision, update_op
return metrics.precision(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_recall(predictions, labels, weights=None,
@@ -647,38 +545,10 @@ def streaming_recall(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'recall', (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
false_negatives, false_negatives_update_op = streaming_false_negatives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
def compute_recall(true_positives, false_negatives, name):
return array_ops.where(
math_ops.greater(true_positives + false_negatives, 0),
math_ops.div(true_positives, true_positives + false_negatives),
0,
name)
recall = compute_recall(true_positives, false_negatives, 'value')
with ops.control_dependencies([true_positives_update_op,
false_negatives_update_op]):
update_op = compute_recall(true_positives, false_negatives, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return recall, update_op
return metrics.recall(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
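Precision and recall follow the same forwarding pattern, composing the true/false positive and negative counters shown earlier. A toy editor-added sketch against the core ops (session API assumed):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 1, 0, 1])

precision, p_up = tf.metrics.precision(labels=labels, predictions=predictions)
recall, r_up = tf.metrics.recall(labels=labels, predictions=predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run([p_up, r_up])
    # tp=2, fp=1, fn=1  ->  precision = recall = 2/3
    print(sess.run([precision, recall]))
```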
def _streaming_confusion_matrix_at_thresholds(
@@ -903,50 +773,10 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'auc', (predictions, labels, weights)):
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' %
(curve))
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds-2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights)
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
fp_rate = math_ops.div(fp, fp + tn + epsilon)
x = fp_rate
y = recall
else: # curve == 'PR'.
precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
x = recall
y = precision
return math_ops.reduce_sum(math_ops.mul(
x[:num_thresholds - 1] - x[1:],
(y[:num_thresholds - 1] + y[1:]) / 2.), name=name)
# sum up the areas of all the trapeziums
auc = compute_auc(
values['tp'], values['fn'], values['tn'], values['fp'], 'value')
update_op = compute_auc(
update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'],
'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, auc)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return auc, update_op
return metrics.auc(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections, num_thresholds=num_thresholds,
curve=curve, updates_collections=updates_collections, name=name)
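The deleted body documents the algorithm that core `metrics.auc` keeps: bucket the confusion counts at `num_thresholds` points, then sum the trapezoids under the resulting ROC (or PR) curve. A hedged usage sketch, editor-added (the approximate value assumes this classic four-point toy example; session API assumed):

```python
import tensorflow as tf

labels = tf.constant([0, 0, 1, 1])
scores = tf.constant([0.1, 0.4, 0.35, 0.8])

# curve='PR' would integrate precision-recall instead of ROC.
auc, update_op = tf.metrics.auc(labels=labels, predictions=scores,
                                num_thresholds=200, curve='ROC')

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(auc))  # ~= 0.75 for this toy batch
```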
def streaming_specificity_at_sensitivity(
@@ -998,60 +828,11 @@
`sensitivity` is not between 0 and 1, or if either `metrics_collections`
or `updates_collections` are not a list or tuple.
"""
if sensitivity < 0 or sensitivity > 1:
raise ValueError('`sensitivity` must be in the range [0, 1].')
with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds-2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights)
tp = values['tp']
fn = values['fn']
tn = values['tn']
fp = values['fp']
def compute_specificity_at_sensitivity(name):
"""Computes the specificity at the given sensitivity.
Args:
name: The name of the operation.
Returns:
The specificity using the aggregated values.
"""
sensitivities = math_ops.div(tp, tp + fn + kepsilon)
# We'll need to use this trick until tf.argmax allows us to specify
# whether we should use the first or last index in case of ties.
min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
indices_at_minval = math_ops.equal(
math_ops.abs(sensitivities - sensitivity), min_val)
indices_at_minval = math_ops.to_int64(indices_at_minval)
indices_at_minval = math_ops.cumsum(indices_at_minval)
tf_index = math_ops.argmax(indices_at_minval, 0)
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the specificity:
return math_ops.div(tn[tf_index],
tn[tf_index] + fp[tf_index] + kepsilon,
name)
specificity = compute_specificity_at_sensitivity('value')
with ops.control_dependencies(update_ops.values()):
update_op = compute_specificity_at_sensitivity('update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, specificity)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return specificity, update_op
return metrics.specificity_at_sensitivity(
sensitivity=sensitivity, num_thresholds=num_thresholds,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
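As the removed code shows, the op scans the per-threshold sensitivities for the point closest to the target and reports the specificity at that implicit threshold. An editor-added usage sketch (toy scores; no output value asserted; session API assumed):

```python
import tensorflow as tf

labels = tf.constant([0, 0, 1, 1])
scores = tf.constant([0.1, 0.4, 0.35, 0.8])

value, update_op = tf.metrics.specificity_at_sensitivity(
    labels=labels, predictions=scores, sensitivity=0.5)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(value))  # specificity at the threshold hitting ~50% recall
```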
def streaming_sensitivity_at_specificity(
@@ -1103,44 +884,11 @@
`specificity` is not between 0 and 1, or if either `metrics_collections`
or `updates_collections` are not a list or tuple.
"""
if specificity < 0 or specificity > 1:
raise ValueError('`specificity` must be in the range [0, 1].')
with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds-2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights)
tp = values['tp']
fn = values['fn']
tn = values['tn']
fp = values['fp']
def compute_sensitivity_at_specificity(name):
specificities = math_ops.div(tn, tn + fp + kepsilon)
tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the sensitivity:
return math_ops.div(tp[tf_index],
tp[tf_index] + fn[tf_index] + kepsilon,
name)
sensitivity = compute_sensitivity_at_specificity('value')
with ops.control_dependencies(update_ops.values()):
update_op = compute_sensitivity_at_specificity('update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, sensitivity)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return sensitivity, update_op
return metrics.sensitivity_at_specificity(
specificity=specificity, num_thresholds=num_thresholds,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_precision_at_thresholds(predictions, labels, thresholds,
@@ -1187,29 +935,11 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'precision_at_thresholds',
(predictions, labels, weights)):
values, update_ops = _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights, includes=('tp', 'fp'))
tp = values['tp']
fp = values['fp']
# Avoid division by zero.
epsilon = 1e-7
def compute_precision(name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
precision = compute_precision('value')
with ops.control_dependencies(update_ops.values()):
update_op = compute_precision('update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, precision)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return precision, update_op
return metrics.precision_at_thresholds(
thresholds=thresholds,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_recall_at_thresholds(predictions, labels, thresholds,
@@ -1253,29 +983,11 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'recall_at_thresholds',
(predictions, labels, weights)):
values, update_ops = _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights, includes=('tp', 'fn'))
tp = values['tp']
fn = values['fn']
# Avoid division by zero.
epsilon = 1e-7
def compute_recall(name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
recall = compute_recall('value')
with ops.control_dependencies(update_ops.values()):
update_op = compute_recall('update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return recall, update_op
return metrics.recall_at_thresholds(
thresholds=thresholds,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
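Both thresholded metrics return one value per entry in `thresholds`, which is useful for sweeping an operating point. A toy editor-added sketch (session API assumed):

```python
import tensorflow as tf

labels = tf.constant([0, 1, 1, 0])
scores = tf.constant([0.2, 0.6, 0.8, 0.4])
thresholds = [0.3, 0.5, 0.7]

prec, p_up = tf.metrics.precision_at_thresholds(
    labels=labels, predictions=scores, thresholds=thresholds)
rec, r_up = tf.metrics.recall_at_thresholds(
    labels=labels, predictions=scores, thresholds=thresholds)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run([p_up, r_up])
    # One precision and one recall per threshold.
    print(sess.run([prec, rec]))
```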
def _at_k_name(name, k=None, class_id=None):
@@ -1413,25 +1125,11 @@ def streaming_sparse_recall_at_k(predictions,
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('recall', k, class_id=class_id)
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
top_k_idx = math_ops.to_int64(top_k_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
fn, fn_update = _streaming_sparse_false_negative_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
if metrics_collections:
ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
return metrics.recall_at_k(
k=k, class_id=class_id,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def _streaming_sparse_precision_at_k(top_k_idx,
@@ -1575,19 +1273,11 @@ def streaming_sparse_precision_at_k(predictions,
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('precision', k, class_id=class_id)
with ops.name_scope(name, default_name,
(predictions, labels, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
return _streaming_sparse_precision_at_k(
top_k_idx=top_k_idx,
labels=labels,
k=k,
class_id=class_id,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
name=scope)
return metrics.sparse_precision_at_k(
k=k, class_id=class_id,
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
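For the sparse top-k metrics, `labels` hold class ids while `predictions` are per-class scores, and the op takes the top `k` scores per example. A minimal editor-added sketch of the core call (int64 labels assumed to be required, per the sparse-metric convention; session API assumed):

```python
import tensorflow as tf

# Two examples, three classes; each row of `labels` is a class id.
labels = tf.constant([[0], [2]], dtype=tf.int64)
scores = tf.constant([[0.6, 0.3, 0.1],
                      [0.2, 0.5, 0.3]])

prec, update_op = tf.metrics.sparse_precision_at_k(
    labels=labels, predictions=scores, k=1)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(prec))  # => 0.5: first top-1 guess is right, second is not
```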
# TODO(ptucker): Validate range of values in labels?
@@ -1918,50 +1608,10 @@ def streaming_sparse_average_precision_at_k(predictions,
update: `Operation` that increments variables appropriately, and whose
value matches `metric`.
"""
default_name = _at_k_name('average_precision', k)
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
# Calculate per-example average precision, and apply weights.
average_precision = sparse_average_precision_at_k(
predictions=predictions, labels=labels, k=k)
if weights is not None:
weights = math_ops.to_double(weights)
average_precision = math_ops.mul(average_precision, weights)
# Create accumulation variables and update ops for max average precision and
# total average precision.
with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
# `max` is the max possible precision. Since max for any row is 1.0:
# - For the unweighted case, this is just the number of rows.
# - For the weighted case, it's the sum of the weights broadcast across
# `average_precision` rows.
max_var = contrib_variables.local_variable(
array_ops.zeros([], dtype=dtypes.float64), name=max_scope)
if weights is None:
batch_max = math_ops.to_double(
array_ops.size(average_precision, name='batch_max'))
else:
# TODO(ptucker): More efficient way to broadcast?
broadcast_weights = math_ops.mul(
weights, array_ops.ones_like(average_precision),
name='broadcast_weights')
batch_max = math_ops.reduce_sum(broadcast_weights, name='batch_max')
max_update = state_ops.assign_add(max_var, batch_max, name='update')
with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
total_var = contrib_variables.local_variable(
array_ops.zeros([], dtype=dtypes.float64), name=total_scope)
batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
update = _safe_scalar_div(total_update, max_update, name=scope)
if metrics_collections:
ops.add_to_collections(metrics_collections, mean_average_precision)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return mean_average_precision, update
return metrics.sparse_average_precision_at_k(
k=k, predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def _select_class_id(ids, selected_id):
@@ -2329,12 +1979,10 @@ def streaming_mean_absolute_error(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
absolute_errors = math_ops.abs(predictions - labels)
return streaming_mean(absolute_errors, weights, metrics_collections,
updates_collections, name or 'mean_absolute_error')
return metrics.mean_absolute_error(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
@@ -2382,19 +2030,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
predictions, normalizer = tensor_util.remove_squeezable_dimensions(
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = array_ops.where(
math_ops.equal(normalizer, 0.0),
array_ops.zeros_like(labels),
math_ops.div(math_ops.abs(labels - predictions), normalizer))
return streaming_mean(relative_errors, weights, metrics_collections,
updates_collections, name or 'mean_relative_error')
return metrics.mean_relative_error(
normalizer=normalizer, predictions=predictions, labels=labels,
weights=weights, metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_mean_squared_error(predictions, labels, weights=None,
@@ -2441,12 +2080,10 @@ def streaming_mean_squared_error(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
squared_error = math_ops.square(labels - predictions)
return streaming_mean(squared_error, weights, metrics_collections,
updates_collections, name or 'mean_squared_error')
return metrics.mean_squared_error(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_root_mean_squared_error(predictions, labels, weights=None,
@@ -2493,24 +2130,10 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
predictions, labels, weights = _remove_squeezable_dimensions(
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
value_tensor, update_op = streaming_mean_squared_error(
predictions, labels, weights, None, None,
name or 'root_mean_squared_error')
root_mean_squared_error = math_ops.sqrt(value_tensor)
with ops.control_dependencies([update_op]):
update_op = math_ops.sqrt(update_op)
if metrics_collections:
ops.add_to_collections(metrics_collections, root_mean_squared_error)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return root_mean_squared_error, update_op
return metrics.root_mean_squared_error(
predictions=predictions, labels=labels, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
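RMSE composes on the pattern above: as the removed code shows, it is the square root of the streaming mean squared error, applied to both the value and the update op. A toy editor-added sketch of the core call (session API assumed):

```python
import tensorflow as tf

labels = tf.constant([1.0, 2.0, 3.0])
predictions = tf.constant([1.0, 2.0, 5.0])

rmse, update_op = tf.metrics.root_mean_squared_error(
    labels=labels, predictions=predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(rmse))  # => sqrt(mean([0, 0, 4])) ~= 1.155
```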
def streaming_covariance(predictions,
@@ -2825,12 +2448,10 @@ def streaming_percentage_less(values, threshold, weights=None,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
return streaming_mean(is_below_threshold,
weights,
metrics_collections,
updates_collections,
name or 'percentage_below_threshold')
return metrics.percentage_below(
values=values, threshold=threshold, weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
def streaming_mean_iou(predictions,
@@ -2881,65 +2502,10 @@ def streaming_mean_iou(predictions,
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
name, 'mean_iou', (predictions, labels, weights)):
# Check if shape is compatible.
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
# Local variable to accumulate the predictions in the confusion matrix.
cm_dtype = dtypes.int64 if weights is not None else dtypes.float64
total_cm = _create_local('total_confusion_matrix',
shape=[num_classes, num_classes], dtype=cm_dtype)
# Cast the type to int64 required by confusion_matrix_ops.
predictions = math_ops.to_int64(predictions)
labels = math_ops.to_int64(labels)
num_classes = math_ops.to_int64(num_classes)
# Flatten the input if its rank > 1.
predictions_rank = predictions.get_shape().ndims
if predictions_rank > 1:
predictions = array_ops.reshape(predictions, [-1])
labels_rank = labels.get_shape().ndims
if labels_rank > 1:
labels = array_ops.reshape(labels, [-1])
if weights is not None:
weights_rank = weights.get_shape().ndims
if weights_rank > 1:
weights = array_ops.reshape(weights, [-1])
# Accumulate the prediction to current confusion matrix.
current_cm = confusion_matrix_ops.confusion_matrix(
predictions, labels, num_classes, weights=weights, dtype=cm_dtype)
update_op = state_ops.assign_add(total_cm, current_cm)
def compute_mean_iou(name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
denominator = sum_over_row + sum_over_col - cm_diag
# If the value of the denominator is 0, set it to 1 to avoid
# zero division.
denominator = array_ops.where(
math_ops.greater(denominator, 0),
denominator,
array_ops.ones_like(denominator))
iou = math_ops.div(cm_diag, denominator)
return math_ops.reduce_mean(iou, name=name)
mean_iou = compute_mean_iou('mean_iou')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean_iou)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return mean_iou, update_op
return metrics.mean_iou(
num_classes=num_classes, predictions=predictions, labels=labels,
weights=weights, metrics_collections=metrics_collections,
updates_collections=updates_collections, name=name)
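The removed body is a good summary of what core `metrics.mean_iou` does: accumulate a `num_classes x num_classes` confusion matrix, then average `diag / (row_sum + col_sum - diag)` over classes. A toy segmentation sketch, editor-added (session API assumed):

```python
import tensorflow as tf

# Flattened segmentation maps with two classes.
labels = tf.constant([0, 0, 1, 1])
predictions = tf.constant([0, 1, 1, 1])

miou, update_op = tf.metrics.mean_iou(
    labels=labels, predictions=predictions, num_classes=2)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    # IoU(class 0) = 1/2, IoU(class 1) = 2/3  ->  mean ~= 0.583
    print(sess.run(miou))
```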
def _next_array_size(required_size, growth_factor=1.5):

View File

@@ -39,6 +39,7 @@ py_library(
":platform",
":platform_test",
":summary",
":metrics",
":layers",
":training",
":ops",
@@ -1312,6 +1313,39 @@ py_library(
],
)
py_library(
name = "confusion_matrix",
srcs = ["ops/confusion_matrix.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
":control_flow_ops",
":framework",
":math_ops",
":sparse_ops",
],
)
py_library(
name = "metrics",
srcs = ["ops/metrics.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
":check_ops",
":confusion_matrix",
":control_flow_ops",
":framework",
":math_ops",
":nn",
":sets",
":sparse_ops",
":state_ops",
":variable_scope",
":variables",
],
)
py_library(
name = "special_math_ops",
srcs = ["ops/special_math_ops.py"],
@@ -1334,6 +1368,7 @@ py_library(
":array_ops",
":check_ops",
":clip_ops",
":confusion_matrix",
":control_flow_ops",
":data_flow_grad",
":data_flow_ops",

View File

@@ -83,6 +83,7 @@ from tensorflow.python.ops.standard_ops import *
# Bring in subpackages.
from tensorflow.python.layers import layers
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import resources
from tensorflow.python.ops import sdca_ops as sdca
@@ -118,6 +119,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import framework_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix as confusion_matrix_m
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import histogram_ops
@@ -220,6 +222,7 @@ _allowed_symbols.extend([
'image',
'logging',
'losses',
'metrics',
'newaxis',
'nn',
'python_io',
@@ -246,10 +249,10 @@ _allowed_symbols.extend([
# referenced in the whitelist.
remove_undocumented(__name__, _allowed_symbols,
[framework_lib, array_ops, client_lib, check_ops,
compat, constant_op, control_flow_ops, functional_ops,
histogram_ops, io_ops, losses, math_ops, nn,
resource_loader, resources, sets, script_ops, session_ops,
sparse_ops, state_ops, string_ops, summary,
compat, constant_op, control_flow_ops, confusion_matrix_m,
functional_ops, histogram_ops, io_ops, losses, math_ops,
metrics, nn, resource_loader, resources, sets, script_ops,
session_ops, sparse_ops, state_ops, string_ops, summary,
tensor_array_ops, train, layers])
# Special dunders that we choose to export:

View File

@@ -260,6 +260,7 @@ EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError",
"tf.contrib.framework.get_global_step",
"tf.contrib.learn.NanLossDuringTrainingError",
"tf.contrib.layers.stack",
"tf.confusion_matrix",
"tf.nn.rnn_cell.RNNCell",
"tf.nn.rnn_cell.BasicRNNCell",
"tf.nn.rnn_cell.BasicLSTMCell",

View File

@@ -1385,6 +1385,21 @@ tf_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)
tf_py_test(
name = "metrics_test",
size = "small",
srcs = ["metrics_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
shard_count = 3,
)
tf_py_test(
name = "confusion_matrix_test",
size = "small",
srcs = ["confusion_matrix_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@@ -28,8 +28,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
def _testConfMatrix(self, predictions, labels, truth, weights=None):
with self.test_session():
dtype = predictions.dtype
ans = tf.contrib.metrics.confusion_matrix(
predictions, labels, dtype=dtype, weights=weights)
ans = tf.confusion_matrix(
labels, predictions, dtype=dtype, weights=weights)
tf_ans = ans.eval()
self.assertAllClose(tf_ans, truth, atol=1e-10)
self.assertEqual(tf_ans.dtype, dtype)
@@ -69,8 +69,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
lab = tf.concat(0, [tf.zeros([20], dtype=tf_dtype),
tf.ones([20], dtype=tf_dtype)])
cm = tf.contrib.metrics.confusion_matrix(
data, lab, dtype=tf_dtype, num_classes=2)
cm = tf.confusion_matrix(
lab, data, dtype=tf_dtype, num_classes=2)
d, l, cm_out = sess.run([data, lab, cm], {m_neg: 0.0,
m_pos: 1.0,
@@ -157,28 +157,28 @@ class ConfusionMatrixTest(tf.test.TestCase):
predictions = np.asarray([[1, 2, 3]])
labels = np.asarray([1, 2, 3])
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
tf.contrib.metrics.confusion_matrix, predictions,
labels)
tf.confusion_matrix,
predictions, labels)
predictions = np.asarray([1, 2, 3])
labels = np.asarray([[1, 2, 3]])
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
tf.contrib.metrics.confusion_matrix, predictions,
labels)
tf.confusion_matrix,
predictions, labels)
def testInputDifferentSize(self):
predictions = np.asarray([1, 2, 3])
labels = np.asarray([1, 2])
self.assertRaisesRegexp(ValueError, "must be equal",
tf.contrib.metrics.confusion_matrix, predictions,
labels)
tf.confusion_matrix,
predictions, labels)
def testOutputIsInt32(self):
predictions = np.arange(2)
labels = np.arange(2)
with self.test_session():
cm = tf.contrib.metrics.confusion_matrix(
predictions, labels, dtype=dtypes.int32)
cm = tf.confusion_matrix(
labels, predictions, dtype=dtypes.int32)
tf_cm = cm.eval()
self.assertEqual(tf_cm.dtype, np.int32)
@@ -186,8 +186,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
predictions = np.arange(2)
labels = np.arange(2)
with self.test_session():
cm = tf.contrib.metrics.confusion_matrix(
predictions, labels, dtype=dtypes.int64)
cm = tf.confusion_matrix(
labels, predictions, dtype=dtypes.int64)
tf_cm = cm.eval()
self.assertEqual(tf_cm.dtype, np.int64)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,163 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Confusion matrix related utilities.
@@remove_squeezable_dimensions
@@confusion_matrix
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
def remove_squeezable_dimensions(labels, predictions, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
Args:
labels: Label values, a `Tensor` whose dimensions match `predictions`.
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
name: Name of the op.
Returns:
Tuple of `labels` and `predictions`, possibly with last dim squeezed.
"""
with ops.name_scope(name, 'remove_squeezable_dimensions',
[labels, predictions]):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
predictions_shape = predictions.get_shape()
predictions_rank = predictions_shape.ndims
labels_shape = labels.get_shape()
labels_rank = labels_shape.ndims
if (labels_rank is not None) and (predictions_rank is not None):
# Use static rank.
rank_diff = predictions_rank - labels_rank
if rank_diff == -1:
labels = array_ops.squeeze(labels, [-1])
elif rank_diff == 1:
predictions = array_ops.squeeze(predictions, [-1])
return labels, predictions
# Use dynamic rank.
rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
if (predictions_rank is None) or (
predictions_shape.dims[-1].is_compatible_with(1)):
predictions = control_flow_ops.cond(
math_ops.equal(1, rank_diff),
lambda: array_ops.squeeze(predictions, [-1]),
lambda: predictions)
if (labels_rank is None) or (
labels_shape.dims[-1].is_compatible_with(1)):
labels = control_flow_ops.cond(
math_ops.equal(-1, rank_diff),
lambda: array_ops.squeeze(labels, [-1]),
lambda: labels)
return labels, predictions
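A small illustration of the static-shape path above: when `predictions` carries one extra trailing dimension of size 1 relative to `labels`, that dimension is squeezed away. This is an editor-added sketch; the helper lives in the internal module added here, not in the public `tf.*` namespace:

```python
import tensorflow as tf
from tensorflow.python.ops import confusion_matrix as cm

labels = tf.placeholder(tf.int64, shape=[4])          # rank 1
predictions = tf.placeholder(tf.int64, shape=[4, 1])  # rank 2, trailing 1

# Static ranks are known, so rank_diff == 1 and predictions get squeezed.
labels_out, predictions_out = cm.remove_squeezable_dimensions(
    labels, predictions)

print(predictions_out.get_shape())  # => (4,)
```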
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
name=None, weights=None):
"""Computes the confusion matrix from predictions and labels.
Calculates the confusion matrix for a pair of 1-D int prediction and
label arrays.
The matrix rows represent the prediction labels and the columns
represent the real labels. The confusion matrix is always a 2-D array
of shape `[n, n]`, where `n` is the number of valid labels for a given
classification task. Both predictions and labels must be 1-D arrays of
the same shape in order for this function to work.
If `num_classes` is None, then `num_classes` will be set to one plus
the maximum value in either predictions or labels.
Class labels are expected to start at 0. E.g., if `num_classes` was
three, then the possible labels would be `[0, 1, 2]`.
If `weights` is not `None`, then each prediction contributes its
corresponding weight to the total value of the confusion matrix cell.
For example:
```python
tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
[[0 0 0 0 0]
[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
```
Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
resulting in a 5x5 confusion matrix.
Args:
labels: A 1-D array representing the real labels for the classification task.
predictions: A 1-D array representing the predictions for a given
classification.
num_classes: The possible number of labels the classification task can
have. If this value is not provided, it will be calculated
using both predictions and labels array.
dtype: Data type of the confusion matrix.
name: Scope name.
weights: An optional `Tensor` whose shape matches `predictions`.
Returns:
A k X k matrix representing the confusion matrix, where k is the number of
possible labels in the classification task.
Raises:
ValueError: If both predictions and labels are not 1-D vectors and have
mismatched shapes, or if `weights` is not `None` and its shape doesn't
match `predictions`.
"""
with ops.name_scope(name, 'confusion_matrix',
[predictions, labels, num_classes]) as name:
labels, predictions = remove_squeezable_dimensions(
ops.convert_to_tensor(labels, name='labels'),
ops.convert_to_tensor(
predictions, name='predictions'))
predictions = math_ops.cast(predictions, dtypes.int64)
labels = math_ops.cast(labels, dtypes.int64)
if num_classes is None:
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
math_ops.reduce_max(labels)) + 1
if weights is not None:
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
weights = math_ops.cast(weights, dtype)
shape = array_ops.pack([num_classes, num_classes])
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
indices=indices, values=values, shape=math_ops.to_int64(shape))
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
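When `weights` is given, each example contributes its weight rather than a count of 1 to its confusion-matrix cell, as the `SparseTensor` construction above shows. A hedged, editor-added sketch (toy data; session API assumed):

```python
import tensorflow as tf

labels = tf.constant([0, 1, 1])
predictions = tf.constant([0, 1, 0])
weights = tf.constant([1.0, 2.0, 0.5])

# With weights, cells accumulate 1.0, 2.0, and 0.5 instead of unit counts.
cm = tf.confusion_matrix(labels, predictions, num_classes=2,
                         dtype=tf.float32, weights=weights)

with tf.Session() as sess:
    print(sess.run(cm))
```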

File diff suppressed because it is too large.

View File

@@ -39,6 +39,7 @@ from tensorflow.python.ops.check_ops import *
from tensorflow.python.ops.clip_ops import *
from tensorflow.python.ops.special_math_ops import *
# TODO(vrv): Switch to import * once we're okay with exposing the module.
from tensorflow.python.ops.confusion_matrix import confusion_matrix
from tensorflow.python.ops.control_flow_ops import Assert
from tensorflow.python.ops.control_flow_ops import group
from tensorflow.python.ops.control_flow_ops import no_op
@@ -91,6 +92,7 @@ from tensorflow.python.framework import constant_op as _constant_op
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import check_ops as _check_ops
from tensorflow.python.ops import clip_ops as _clip_ops
from tensorflow.python.ops import confusion_matrix as _confusion_matrix
from tensorflow.python.ops import control_flow_ops as _control_flow_ops
from tensorflow.python.ops import data_flow_ops as _data_flow_ops
from tensorflow.python.ops import functional_ops as _functional_ops
@@ -244,6 +246,7 @@ _allowed_symbols_misc = [
"parse_single_sequence_example",
"serialize_many_sparse",
"serialize_sparse",
"confusion_matrix",
]
_allowed_symbols = (_allowed_symbols_array_ops +
@@ -262,6 +265,7 @@ remove_undocumented(__name__, _allowed_symbols,
_array_ops,
_check_ops,
_clip_ops,
_confusion_matrix,
_control_flow_ops,
_constant_op,
_data_flow_ops,