Moves most metrics from contrib into core.
Change: 140914784
This commit is contained in:
parent
ccf6cf533c
commit
0e5015bb7d
@ -33,19 +33,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "confusion_matrix_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/kernel_tests/confusion_matrix_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":metrics_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "histogram_ops_test",
|
||||
size = "medium",
|
||||
|
@ -133,7 +133,6 @@ labels and predictions tensors and results in a weighted average of the metric.
|
||||
@@auc_using_histogram
|
||||
|
||||
@@accuracy
|
||||
@@confusion_matrix
|
||||
|
||||
@@aggregate_metrics
|
||||
@@aggregate_metric_map
|
||||
|
@ -18,93 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework import tensor_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import confusion_matrix as cm
|
||||
|
||||
|
||||
def confusion_matrix(predictions, labels, num_classes=None, dtype=dtypes.int32,
|
||||
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
|
||||
name=None, weights=None):
|
||||
"""Computes the confusion matrix from predictions and labels.
|
||||
|
||||
Calculate the Confusion Matrix for a pair of prediction and
|
||||
label 1-D int arrays.
|
||||
|
||||
The matrix rows represent the prediction labels and the columns
|
||||
represents the real labels. The confusion matrix is always a 2-D array
|
||||
of shape `[n, n]`, where `n` is the number of valid labels for a given
|
||||
classification task. Both prediction and labels must be 1-D arrays of
|
||||
the same shape in order for this function to work.
|
||||
|
||||
If `num_classes` is None, then `num_classes` will be set to the one plus
|
||||
the maximum value in either predictions or labels.
|
||||
Class labels are expected to start at 0. E.g., if `num_classes` was
|
||||
three, then the possible labels would be `[0, 1, 2]`.
|
||||
|
||||
If `weights` is not `None`, then each prediction contributes its
|
||||
corresponding weight to the total value of the confusion matrix cell.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
|
||||
[[0 0 0 0 0]
|
||||
[0 0 1 0 0]
|
||||
[0 0 1 0 0]
|
||||
[0 0 0 0 0]
|
||||
[0 0 0 0 1]]
|
||||
```
|
||||
|
||||
Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
|
||||
resulting in a 5x5 confusion matrix.
|
||||
|
||||
Args:
|
||||
predictions: A 1-D array representing the predictions for a given
|
||||
classification.
|
||||
labels: A 1-D representing the real labels for the classification task.
|
||||
num_classes: The possible number of labels the classification task can
|
||||
have. If this value is not provided, it will be calculated
|
||||
using both predictions and labels array.
|
||||
dtype: Data type of the confusion matrix.
|
||||
name: Scope name.
|
||||
weights: An optional `Tensor` whose shape matches `predictions`.
|
||||
|
||||
Returns:
|
||||
A k X k matrix representing the confusion matrix, where k is the number of
|
||||
possible labels in the classification task.
|
||||
|
||||
Raises:
|
||||
ValueError: If both predictions and labels are not 1-D vectors and have
|
||||
mismatched shapes, or if `weights` is not `None` and its shape doesn't
|
||||
match `predictions`.
|
||||
"""
|
||||
with ops.name_scope(name, 'confusion_matrix',
|
||||
[predictions, labels, num_classes]) as name:
|
||||
predictions, labels = tensor_util.remove_squeezable_dimensions(
|
||||
ops.convert_to_tensor(
|
||||
predictions, name='predictions'),
|
||||
ops.convert_to_tensor(labels, name='labels'))
|
||||
predictions = math_ops.cast(predictions, dtypes.int64)
|
||||
labels = math_ops.cast(labels, dtypes.int64)
|
||||
|
||||
if num_classes is None:
|
||||
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
|
||||
math_ops.reduce_max(labels)) + 1
|
||||
|
||||
if weights is not None:
|
||||
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
|
||||
weights = math_ops.cast(weights, dtype)
|
||||
|
||||
shape = array_ops.pack([num_classes, num_classes])
|
||||
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
|
||||
values = (array_ops.ones_like(predictions, dtype)
|
||||
if weights is None else weights)
|
||||
cm_sparse = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=values, shape=math_ops.to_int64(shape))
|
||||
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
|
||||
|
||||
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
|
||||
"""Deprecated. Use tf.confusion_matrix instead."""
|
||||
return cm.confusion_matrix(labels=labels, predictions=predictions,
|
||||
num_classes=num_classes, dtype=dtype, name=name,
|
||||
weights=weights)
|
||||
|
@ -25,7 +25,6 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework import tensor_util
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
|
||||
from tensorflow.contrib.metrics.python.ops import set_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -34,6 +33,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
@ -178,16 +178,10 @@ def streaming_true_positives(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'true_positives', (predictions, labels, weights)):
|
||||
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1),
|
||||
math_ops.equal(predictions, 1))
|
||||
return _count_condition(is_true_positive, weights, metrics_collections,
|
||||
updates_collections)
|
||||
return metrics.true_positives(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_true_negatives(predictions, labels, weights=None,
|
||||
@ -262,16 +256,10 @@ def streaming_false_positives(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'false_positives', (predictions, labels, weights)):
|
||||
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0),
|
||||
math_ops.equal(predictions, 1))
|
||||
return _count_condition(is_false_positive, weights, metrics_collections,
|
||||
updates_collections)
|
||||
return metrics.false_positives(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_false_negatives(predictions, labels, weights=None,
|
||||
@ -303,16 +291,10 @@ def streaming_false_negatives(predictions, labels, weights=None,
|
||||
or if either `metrics_collections` or `updates_collections` are not a list
|
||||
or tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'false_negatives', (predictions, labels, weights)):
|
||||
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1),
|
||||
math_ops.equal(predictions, 0))
|
||||
return _count_condition(is_false_negative, weights, metrics_collections,
|
||||
updates_collections)
|
||||
return metrics.false_negatives(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _broadcast_weights(weights, values):
|
||||
@ -376,33 +358,9 @@ def streaming_mean(values, weights=None, metrics_collections=None,
|
||||
or if either `metrics_collections` or `updates_collections` are not a list
|
||||
or tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(name, 'mean', (values, weights)):
|
||||
values = math_ops.to_float(values)
|
||||
|
||||
total = _create_local('total', shape=[])
|
||||
count = _create_local('count', shape=[])
|
||||
|
||||
if weights is not None:
|
||||
weights = math_ops.to_float(weights)
|
||||
values = math_ops.mul(values, weights)
|
||||
num_values = math_ops.reduce_sum(_broadcast_weights(weights, values))
|
||||
else:
|
||||
num_values = math_ops.to_float(array_ops.size(values))
|
||||
|
||||
total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
|
||||
count_compute_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
mean = _safe_div(total, count, 'value')
|
||||
with ops.control_dependencies([total_compute_op, count_compute_op]):
|
||||
update_op = _safe_div(total, count, 'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return mean, update_op
|
||||
return metrics.mean(
|
||||
values=values, weights=weights, metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
|
||||
@ -445,36 +403,9 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None,
|
||||
or if either `metrics_collections` or `updates_collections` are not a list
|
||||
or tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(name, 'mean', (values, weights)):
|
||||
total = _create_local('total_tensor', shape=values.get_shape())
|
||||
count = _create_local('count_tensor', shape=values.get_shape())
|
||||
|
||||
num_values = array_ops.ones_like(values)
|
||||
if weights is not None:
|
||||
weights = math_ops.to_float(weights)
|
||||
values = math_ops.mul(values, weights)
|
||||
num_values = math_ops.mul(num_values, weights)
|
||||
|
||||
total_compute_op = state_ops.assign_add(total, values)
|
||||
count_compute_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
def compute_mean(total, count, name):
|
||||
non_zero_count = math_ops.maximum(count,
|
||||
array_ops.ones_like(count),
|
||||
name=name)
|
||||
return math_ops.truediv(total, non_zero_count, name=name)
|
||||
|
||||
mean = compute_mean(total, count, 'value')
|
||||
with ops.control_dependencies([total_compute_op, count_compute_op]):
|
||||
update_op = compute_mean(total, count, 'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return mean, update_op
|
||||
return metrics.mean_tensor(
|
||||
values=values, weights=weights, metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_accuracy(predictions, labels, weights=None,
|
||||
@ -520,14 +451,10 @@ def streaming_accuracy(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights=weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
if labels.dtype != predictions.dtype:
|
||||
predictions = math_ops.cast(predictions, labels.dtype)
|
||||
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
|
||||
return streaming_mean(is_correct, weights, metrics_collections,
|
||||
updates_collections, name or 'accuracy')
|
||||
return metrics.accuracy(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_precision(predictions, labels, weights=None,
|
||||
@ -572,39 +499,10 @@ def streaming_precision(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'precision', (predictions, labels, weights)):
|
||||
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
|
||||
true_positives, true_positives_update_op = streaming_true_positives(
|
||||
predictions, labels, weights, metrics_collections=None,
|
||||
updates_collections=None, name=None)
|
||||
false_positives, false_positives_update_op = streaming_false_positives(
|
||||
predictions, labels, weights, metrics_collections=None,
|
||||
updates_collections=None, name=None)
|
||||
|
||||
def compute_precision(name):
|
||||
return array_ops.where(
|
||||
math_ops.greater(true_positives + false_positives, 0),
|
||||
math_ops.div(true_positives, true_positives + false_positives),
|
||||
0,
|
||||
name)
|
||||
|
||||
precision = compute_precision('value')
|
||||
with ops.control_dependencies([true_positives_update_op,
|
||||
false_positives_update_op]):
|
||||
update_op = compute_precision('update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, precision)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return precision, update_op
|
||||
return metrics.precision(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_recall(predictions, labels, weights=None,
|
||||
@ -647,38 +545,10 @@ def streaming_recall(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'recall', (predictions, labels, weights)):
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
|
||||
true_positives, true_positives_update_op = streaming_true_positives(
|
||||
predictions, labels, weights, metrics_collections=None,
|
||||
updates_collections=None, name=None)
|
||||
false_negatives, false_negatives_update_op = streaming_false_negatives(
|
||||
predictions, labels, weights, metrics_collections=None,
|
||||
updates_collections=None, name=None)
|
||||
|
||||
def compute_recall(true_positives, false_negatives, name):
|
||||
return array_ops.where(
|
||||
math_ops.greater(true_positives + false_negatives, 0),
|
||||
math_ops.div(true_positives, true_positives + false_negatives),
|
||||
0,
|
||||
name)
|
||||
|
||||
recall = compute_recall(true_positives, false_negatives, 'value')
|
||||
with ops.control_dependencies([true_positives_update_op,
|
||||
false_negatives_update_op]):
|
||||
update_op = compute_recall(true_positives, false_negatives, 'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, recall)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return recall, update_op
|
||||
return metrics.recall(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _streaming_confusion_matrix_at_thresholds(
|
||||
@ -903,50 +773,10 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'auc', (predictions, labels, weights)):
|
||||
if curve != 'ROC' and curve != 'PR':
|
||||
raise ValueError('curve must be either ROC or PR, %s unknown' %
|
||||
(curve))
|
||||
kepsilon = 1e-7 # to account for floating point imprecisions
|
||||
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
|
||||
for i in range(num_thresholds-2)]
|
||||
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
|
||||
|
||||
values, update_ops = _streaming_confusion_matrix_at_thresholds(
|
||||
predictions, labels, thresholds, weights)
|
||||
|
||||
# Add epsilons to avoid dividing by 0.
|
||||
epsilon = 1.0e-6
|
||||
def compute_auc(tp, fn, tn, fp, name):
|
||||
"""Computes the roc-auc or pr-auc based on confusion counts."""
|
||||
recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
|
||||
if curve == 'ROC':
|
||||
fp_rate = math_ops.div(fp, fp + tn + epsilon)
|
||||
x = fp_rate
|
||||
y = recall
|
||||
else: # curve == 'PR'.
|
||||
precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
|
||||
x = recall
|
||||
y = precision
|
||||
return math_ops.reduce_sum(math_ops.mul(
|
||||
x[:num_thresholds - 1] - x[1:],
|
||||
(y[:num_thresholds - 1] + y[1:]) / 2.), name=name)
|
||||
|
||||
# sum up the areas of all the trapeziums
|
||||
auc = compute_auc(
|
||||
values['tp'], values['fn'], values['tn'], values['fp'], 'value')
|
||||
update_op = compute_auc(
|
||||
update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'],
|
||||
'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, auc)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return auc, update_op
|
||||
return metrics.auc(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections, num_thresholds=num_thresholds,
|
||||
curve=curve, updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_specificity_at_sensitivity(
|
||||
@ -998,60 +828,11 @@ def streaming_specificity_at_sensitivity(
|
||||
`sensitivity` is not between 0 and 1, or if either `metrics_collections`
|
||||
or `updates_collections` are not a list or tuple.
|
||||
"""
|
||||
if sensitivity < 0 or sensitivity > 1:
|
||||
raise ValueError('`sensitivity` must be in the range [0, 1].')
|
||||
|
||||
with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
|
||||
(predictions, labels, weights)):
|
||||
kepsilon = 1e-7 # to account for floating point imprecisions
|
||||
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
|
||||
for i in range(num_thresholds-2)]
|
||||
thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
|
||||
|
||||
values, update_ops = _streaming_confusion_matrix_at_thresholds(
|
||||
predictions, labels, thresholds, weights)
|
||||
tp = values['tp']
|
||||
fn = values['fn']
|
||||
tn = values['tn']
|
||||
fp = values['fp']
|
||||
|
||||
def compute_specificity_at_sensitivity(name):
|
||||
"""Computes the specificity at the given sensitivity.
|
||||
|
||||
Args:
|
||||
name: The name of the operation.
|
||||
|
||||
Returns:
|
||||
The specificity using the aggregated values.
|
||||
"""
|
||||
sensitivities = math_ops.div(tp, tp + fn + kepsilon)
|
||||
|
||||
# We'll need to use this trick until tf.argmax allows us to specify
|
||||
# whether we should use the first or last index in case of ties.
|
||||
min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
|
||||
indices_at_minval = math_ops.equal(
|
||||
math_ops.abs(sensitivities - sensitivity), min_val)
|
||||
indices_at_minval = math_ops.to_int64(indices_at_minval)
|
||||
indices_at_minval = math_ops.cumsum(indices_at_minval)
|
||||
tf_index = math_ops.argmax(indices_at_minval, 0)
|
||||
tf_index = math_ops.cast(tf_index, dtypes.int32)
|
||||
|
||||
# Now, we have the implicit threshold, so compute the specificity:
|
||||
return math_ops.div(tn[tf_index],
|
||||
tn[tf_index] + fp[tf_index] + kepsilon,
|
||||
name)
|
||||
|
||||
specificity = compute_specificity_at_sensitivity('value')
|
||||
with ops.control_dependencies(update_ops.values()):
|
||||
update_op = compute_specificity_at_sensitivity('update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, specificity)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return specificity, update_op
|
||||
return metrics.specificity_at_sensitivity(
|
||||
sensitivity=sensitivity, num_thresholds=num_thresholds,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_sensitivity_at_specificity(
|
||||
@ -1103,44 +884,11 @@ def streaming_sensitivity_at_specificity(
|
||||
`specificity` is not between 0 and 1, or if either `metrics_collections`
|
||||
or `updates_collections` are not a list or tuple.
|
||||
"""
|
||||
if specificity < 0 or specificity > 1:
|
||||
raise ValueError('`specificity` must be in the range [0, 1].')
|
||||
|
||||
with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
|
||||
(predictions, labels, weights)):
|
||||
kepsilon = 1e-7 # to account for floating point imprecisions
|
||||
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
|
||||
for i in range(num_thresholds-2)]
|
||||
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
|
||||
|
||||
values, update_ops = _streaming_confusion_matrix_at_thresholds(
|
||||
predictions, labels, thresholds, weights)
|
||||
tp = values['tp']
|
||||
fn = values['fn']
|
||||
tn = values['tn']
|
||||
fp = values['fp']
|
||||
|
||||
def compute_sensitivity_at_specificity(name):
|
||||
specificities = math_ops.div(tn, tn + fp + kepsilon)
|
||||
tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
|
||||
tf_index = math_ops.cast(tf_index, dtypes.int32)
|
||||
|
||||
# Now, we have the implicit threshold, so compute the sensitivity:
|
||||
return math_ops.div(tp[tf_index],
|
||||
tp[tf_index] + fn[tf_index] + kepsilon,
|
||||
name)
|
||||
|
||||
sensitivity = compute_sensitivity_at_specificity('value')
|
||||
with ops.control_dependencies(update_ops.values()):
|
||||
update_op = compute_sensitivity_at_specificity('update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, sensitivity)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return sensitivity, update_op
|
||||
return metrics.sensitivity_at_specificity(
|
||||
specificity=specificity, num_thresholds=num_thresholds,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_precision_at_thresholds(predictions, labels, thresholds,
|
||||
@ -1187,29 +935,11 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(name, 'precision_at_thresholds',
|
||||
(predictions, labels, weights)):
|
||||
values, update_ops = _streaming_confusion_matrix_at_thresholds(
|
||||
predictions, labels, thresholds, weights, includes=('tp', 'fp'))
|
||||
tp = values['tp']
|
||||
fp = values['fp']
|
||||
|
||||
# Avoid division by zero.
|
||||
epsilon = 1e-7
|
||||
def compute_precision(name):
|
||||
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
|
||||
|
||||
precision = compute_precision('value')
|
||||
with ops.control_dependencies(update_ops.values()):
|
||||
update_op = compute_precision('update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, precision)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return precision, update_op
|
||||
return metrics.precision_at_thresholds(
|
||||
thresholds=thresholds,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_recall_at_thresholds(predictions, labels, thresholds,
|
||||
@ -1253,29 +983,11 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(name, 'recall_at_thresholds',
|
||||
(predictions, labels, weights)):
|
||||
values, update_ops = _streaming_confusion_matrix_at_thresholds(
|
||||
predictions, labels, thresholds, weights, includes=('tp', 'fn'))
|
||||
tp = values['tp']
|
||||
fn = values['fn']
|
||||
|
||||
# Avoid division by zero.
|
||||
epsilon = 1e-7
|
||||
def compute_recall(name):
|
||||
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
|
||||
|
||||
recall = compute_recall('value')
|
||||
with ops.control_dependencies(update_ops.values()):
|
||||
update_op = compute_recall('update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, recall)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return recall, update_op
|
||||
return metrics.recall_at_thresholds(
|
||||
thresholds=thresholds,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _at_k_name(name, k=None, class_id=None):
|
||||
@ -1413,25 +1125,11 @@ def streaming_sparse_recall_at_k(predictions,
|
||||
`predictions`, or if either `metrics_collections` or `updates_collections`
|
||||
are not a list or tuple.
|
||||
"""
|
||||
default_name = _at_k_name('recall', k, class_id=class_id)
|
||||
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
|
||||
_, top_k_idx = nn.top_k(predictions, k)
|
||||
top_k_idx = math_ops.to_int64(top_k_idx)
|
||||
tp, tp_update = _streaming_sparse_true_positive_at_k(
|
||||
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
|
||||
weights=weights)
|
||||
fn, fn_update = _streaming_sparse_false_negative_at_k(
|
||||
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
|
||||
weights=weights)
|
||||
|
||||
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
|
||||
update = math_ops.div(
|
||||
tp_update, math_ops.add(tp_update, fn_update), name='update')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, metric)
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update)
|
||||
return metric, update
|
||||
return metrics.recall_at_k(
|
||||
k=k, class_id=class_id,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _streaming_sparse_precision_at_k(top_k_idx,
|
||||
@ -1575,19 +1273,11 @@ def streaming_sparse_precision_at_k(predictions,
|
||||
`predictions`, or if either `metrics_collections` or `updates_collections`
|
||||
are not a list or tuple.
|
||||
"""
|
||||
default_name = _at_k_name('precision', k, class_id=class_id)
|
||||
with ops.name_scope(name, default_name,
|
||||
(predictions, labels, weights)) as scope:
|
||||
_, top_k_idx = nn.top_k(predictions, k)
|
||||
return _streaming_sparse_precision_at_k(
|
||||
top_k_idx=top_k_idx,
|
||||
labels=labels,
|
||||
k=k,
|
||||
class_id=class_id,
|
||||
weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections,
|
||||
name=scope)
|
||||
return metrics.sparse_precision_at_k(
|
||||
k=k, class_id=class_id,
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
# TODO(ptucker): Validate range of values in labels?
|
||||
@ -1918,50 +1608,10 @@ def streaming_sparse_average_precision_at_k(predictions,
|
||||
update: `Operation` that increments variables appropriately, and whose
|
||||
value matches `metric`.
|
||||
"""
|
||||
default_name = _at_k_name('average_precision', k)
|
||||
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
|
||||
# Calculate per-example average precision, and apply weights.
|
||||
average_precision = sparse_average_precision_at_k(
|
||||
predictions=predictions, labels=labels, k=k)
|
||||
if weights is not None:
|
||||
weights = math_ops.to_double(weights)
|
||||
average_precision = math_ops.mul(average_precision, weights)
|
||||
|
||||
# Create accumulation variables and update ops for max average precision and
|
||||
# total average precision.
|
||||
with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
|
||||
# `max` is the max possible precision. Since max for any row is 1.0:
|
||||
# - For the unweighted case, this is just the number of rows.
|
||||
# - For the weighted case, it's the sum of the weights broadcast across
|
||||
# `average_precision` rows.
|
||||
max_var = contrib_variables.local_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float64), name=max_scope)
|
||||
if weights is None:
|
||||
batch_max = math_ops.to_double(
|
||||
array_ops.size(average_precision, name='batch_max'))
|
||||
else:
|
||||
# TODO(ptucker): More efficient way to broadcast?
|
||||
broadcast_weights = math_ops.mul(
|
||||
weights, array_ops.ones_like(average_precision),
|
||||
name='broadcast_weights')
|
||||
batch_max = math_ops.reduce_sum(broadcast_weights, name='batch_max')
|
||||
max_update = state_ops.assign_add(max_var, batch_max, name='update')
|
||||
with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
|
||||
total_var = contrib_variables.local_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float64), name=total_scope)
|
||||
batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
|
||||
total_update = state_ops.assign_add(total_var, batch_total, name='update')
|
||||
|
||||
# Divide total by max to get mean, for both vars and the update ops.
|
||||
mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
|
||||
update = _safe_scalar_div(total_update, max_update, name=scope)
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_average_precision)
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update)
|
||||
|
||||
return mean_average_precision, update
|
||||
return metrics.sparse_average_precision_at_k(
|
||||
k=k, predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _select_class_id(ids, selected_id):
|
||||
@ -2329,12 +1979,10 @@ def streaming_mean_absolute_error(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
absolute_errors = math_ops.abs(predictions - labels)
|
||||
return streaming_mean(absolute_errors, weights, metrics_collections,
|
||||
updates_collections, name or 'mean_absolute_error')
|
||||
return metrics.mean_absolute_error(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
|
||||
@ -2382,19 +2030,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
|
||||
predictions, normalizer = tensor_util.remove_squeezable_dimensions(
|
||||
predictions, normalizer)
|
||||
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
|
||||
relative_errors = array_ops.where(
|
||||
math_ops.equal(normalizer, 0.0),
|
||||
array_ops.zeros_like(labels),
|
||||
math_ops.div(math_ops.abs(labels - predictions), normalizer))
|
||||
return streaming_mean(relative_errors, weights, metrics_collections,
|
||||
updates_collections, name or 'mean_relative_error')
|
||||
return metrics.mean_relative_error(
|
||||
normalizer=normalizer, predictions=predictions, labels=labels,
|
||||
weights=weights, metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_mean_squared_error(predictions, labels, weights=None,
|
||||
@ -2441,12 +2080,10 @@ def streaming_mean_squared_error(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
squared_error = math_ops.square(labels - predictions)
|
||||
return streaming_mean(squared_error, weights, metrics_collections,
|
||||
updates_collections, name or 'mean_squared_error')
|
||||
return metrics.mean_squared_error(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_root_mean_squared_error(predictions, labels, weights=None,
|
||||
@ -2493,24 +2130,10 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
predictions, labels, weights = _remove_squeezable_dimensions(
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
value_tensor, update_op = streaming_mean_squared_error(
|
||||
predictions, labels, weights, None, None,
|
||||
name or 'root_mean_squared_error')
|
||||
|
||||
root_mean_squared_error = math_ops.sqrt(value_tensor)
|
||||
with ops.control_dependencies([update_op]):
|
||||
update_op = math_ops.sqrt(update_op)
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, root_mean_squared_error)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return root_mean_squared_error, update_op
|
||||
return metrics.root_mean_squared_error(
|
||||
predictions=predictions, labels=labels, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_covariance(predictions,
|
||||
@ -2825,12 +2448,10 @@ def streaming_percentage_less(values, threshold, weights=None,
|
||||
or if either `metrics_collections` or `updates_collections` are not a list
|
||||
or tuple.
|
||||
"""
|
||||
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
|
||||
return streaming_mean(is_below_threshold,
|
||||
weights,
|
||||
metrics_collections,
|
||||
updates_collections,
|
||||
name or 'percentage_below_threshold')
|
||||
return metrics.percentage_below(
|
||||
values=values, threshold=threshold, weights=weights,
|
||||
metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def streaming_mean_iou(predictions,
|
||||
@ -2881,65 +2502,10 @@ def streaming_mean_iou(predictions,
|
||||
either `metrics_collections` or `updates_collections` are not a list or
|
||||
tuple.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
name, 'mean_iou', (predictions, labels, weights)):
|
||||
# Check if shape is compatible.
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
|
||||
# Local variable to accumulate the predictions in the confusion matrix.
|
||||
cm_dtype = dtypes.int64 if weights is not None else dtypes.float64
|
||||
total_cm = _create_local('total_confusion_matrix',
|
||||
shape=[num_classes, num_classes], dtype=cm_dtype)
|
||||
|
||||
# Cast the type to int64 required by confusion_matrix_ops.
|
||||
predictions = math_ops.to_int64(predictions)
|
||||
labels = math_ops.to_int64(labels)
|
||||
num_classes = math_ops.to_int64(num_classes)
|
||||
|
||||
# Flatten the input if its rank > 1.
|
||||
predictions_rank = predictions.get_shape().ndims
|
||||
if predictions_rank > 1:
|
||||
predictions = array_ops.reshape(predictions, [-1])
|
||||
|
||||
labels_rank = labels.get_shape().ndims
|
||||
if labels_rank > 1:
|
||||
labels = array_ops.reshape(labels, [-1])
|
||||
|
||||
if weights is not None:
|
||||
weights_rank = weights.get_shape().ndims
|
||||
if weights_rank > 1:
|
||||
weights = array_ops.reshape(weights, [-1])
|
||||
|
||||
# Accumulate the prediction to current confusion matrix.
|
||||
current_cm = confusion_matrix_ops.confusion_matrix(
|
||||
predictions, labels, num_classes, weights=weights, dtype=cm_dtype)
|
||||
update_op = state_ops.assign_add(total_cm, current_cm)
|
||||
|
||||
def compute_mean_iou(name):
|
||||
"""Compute the mean intersection-over-union via the confusion matrix."""
|
||||
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
|
||||
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
|
||||
cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
|
||||
denominator = sum_over_row + sum_over_col - cm_diag
|
||||
|
||||
# If the value of the denominator is 0, set it to 1 to avoid
|
||||
# zero division.
|
||||
denominator = array_ops.where(
|
||||
math_ops.greater(denominator, 0),
|
||||
denominator,
|
||||
array_ops.ones_like(denominator))
|
||||
iou = math_ops.div(cm_diag, denominator)
|
||||
return math_ops.reduce_mean(iou, name=name)
|
||||
|
||||
mean_iou = compute_mean_iou('mean_iou')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_iou)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return mean_iou, update_op
|
||||
return metrics.mean_iou(
|
||||
num_classes=num_classes, predictions=predictions, labels=labels,
|
||||
weights=weights, metrics_collections=metrics_collections,
|
||||
updates_collections=updates_collections, name=name)
|
||||
|
||||
|
||||
def _next_array_size(required_size, growth_factor=1.5):
|
||||
|
@ -39,6 +39,7 @@ py_library(
|
||||
":platform",
|
||||
":platform_test",
|
||||
":summary",
|
||||
":metrics",
|
||||
":layers",
|
||||
":training",
|
||||
":ops",
|
||||
@ -1312,6 +1313,39 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "confusion_matrix",
|
||||
srcs = ["ops/confusion_matrix.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":control_flow_ops",
|
||||
":framework",
|
||||
":math_ops",
|
||||
":sparse_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metrics",
|
||||
srcs = ["ops/metrics.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":check_ops",
|
||||
":confusion_matrix",
|
||||
":control_flow_ops",
|
||||
":framework",
|
||||
":math_ops",
|
||||
":nn",
|
||||
":sets",
|
||||
":sparse_ops",
|
||||
":state_ops",
|
||||
":variable_scope",
|
||||
":variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "special_math_ops",
|
||||
srcs = ["ops/special_math_ops.py"],
|
||||
@ -1334,6 +1368,7 @@ py_library(
|
||||
":array_ops",
|
||||
":check_ops",
|
||||
":clip_ops",
|
||||
":confusion_matrix",
|
||||
":control_flow_ops",
|
||||
":data_flow_grad",
|
||||
":data_flow_ops",
|
||||
|
@ -83,6 +83,7 @@ from tensorflow.python.ops.standard_ops import *
|
||||
|
||||
# Bring in subpackages.
|
||||
from tensorflow.python.layers import layers
|
||||
from tensorflow.python.ops import metrics
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import sdca_ops as sdca
|
||||
@ -118,6 +119,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import framework_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import confusion_matrix as confusion_matrix_m
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import histogram_ops
|
||||
@ -220,6 +222,7 @@ _allowed_symbols.extend([
|
||||
'image',
|
||||
'logging',
|
||||
'losses',
|
||||
'metrics',
|
||||
'newaxis',
|
||||
'nn',
|
||||
'python_io',
|
||||
@ -246,10 +249,10 @@ _allowed_symbols.extend([
|
||||
# referenced in the whitelist.
|
||||
remove_undocumented(__name__, _allowed_symbols,
|
||||
[framework_lib, array_ops, client_lib, check_ops,
|
||||
compat, constant_op, control_flow_ops, functional_ops,
|
||||
histogram_ops, io_ops, losses, math_ops, nn,
|
||||
resource_loader, resources, sets, script_ops, session_ops,
|
||||
sparse_ops, state_ops, string_ops, summary,
|
||||
compat, constant_op, control_flow_ops, confusion_matrix_m,
|
||||
functional_ops, histogram_ops, io_ops, losses, math_ops,
|
||||
metrics, nn, resource_loader, resources, sets, script_ops,
|
||||
session_ops, sparse_ops, state_ops, string_ops, summary,
|
||||
tensor_array_ops, train, layers])
|
||||
|
||||
# Special dunders that we choose to export:
|
||||
|
@ -260,6 +260,7 @@ EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError",
|
||||
"tf.contrib.framework.get_global_step",
|
||||
"tf.contrib.learn.NanLossDuringTrainingError",
|
||||
"tf.contrib.layers.stack",
|
||||
"tf.confusion_matrix",
|
||||
"tf.nn.rnn_cell.RNNCell",
|
||||
"tf.nn.rnn_cell.BasicRNNCell",
|
||||
"tf.nn.rnn_cell.BasicLSTMCell",
|
||||
|
@ -1385,6 +1385,21 @@ tf_py_test(
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "metrics_test",
|
||||
size = "small",
|
||||
srcs = ["metrics_test.py"],
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
shard_count = 3,
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "confusion_matrix_test",
|
||||
size = "small",
|
||||
srcs = ["confusion_matrix_test.py"],
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -28,8 +28,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
||||
def _testConfMatrix(self, predictions, labels, truth, weights=None):
|
||||
with self.test_session():
|
||||
dtype = predictions.dtype
|
||||
ans = tf.contrib.metrics.confusion_matrix(
|
||||
predictions, labels, dtype=dtype, weights=weights)
|
||||
ans = tf.confusion_matrix(
|
||||
labels, predictions, dtype=dtype, weights=weights)
|
||||
tf_ans = ans.eval()
|
||||
self.assertAllClose(tf_ans, truth, atol=1e-10)
|
||||
self.assertEqual(tf_ans.dtype, dtype)
|
||||
@ -69,8 +69,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
||||
lab = tf.concat(0, [tf.zeros([20], dtype=tf_dtype),
|
||||
tf.ones([20], dtype=tf_dtype)])
|
||||
|
||||
cm = tf.contrib.metrics.confusion_matrix(
|
||||
data, lab, dtype=tf_dtype, num_classes=2)
|
||||
cm = tf.confusion_matrix(
|
||||
lab, data, dtype=tf_dtype, num_classes=2)
|
||||
|
||||
d, l, cm_out = sess.run([data, lab, cm], {m_neg: 0.0,
|
||||
m_pos: 1.0,
|
||||
@ -157,28 +157,28 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
||||
predictions = np.asarray([[1, 2, 3]])
|
||||
labels = np.asarray([1, 2, 3])
|
||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
||||
tf.contrib.metrics.confusion_matrix, predictions,
|
||||
labels)
|
||||
tf.confusion_matrix,
|
||||
predictions, labels)
|
||||
|
||||
predictions = np.asarray([1, 2, 3])
|
||||
labels = np.asarray([[1, 2, 3]])
|
||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
||||
tf.contrib.metrics.confusion_matrix, predictions,
|
||||
labels)
|
||||
tf.confusion_matrix,
|
||||
predictions, labels)
|
||||
|
||||
def testInputDifferentSize(self):
|
||||
predictions = np.asarray([1, 2, 3])
|
||||
labels = np.asarray([1, 2])
|
||||
self.assertRaisesRegexp(ValueError, "must be equal",
|
||||
tf.contrib.metrics.confusion_matrix, predictions,
|
||||
labels)
|
||||
tf.confusion_matrix,
|
||||
predictions, labels)
|
||||
|
||||
def testOutputIsInt32(self):
|
||||
predictions = np.arange(2)
|
||||
labels = np.arange(2)
|
||||
with self.test_session():
|
||||
cm = tf.contrib.metrics.confusion_matrix(
|
||||
predictions, labels, dtype=dtypes.int32)
|
||||
cm = tf.confusion_matrix(
|
||||
labels, predictions, dtype=dtypes.int32)
|
||||
tf_cm = cm.eval()
|
||||
self.assertEqual(tf_cm.dtype, np.int32)
|
||||
|
||||
@ -186,8 +186,8 @@ class ConfusionMatrixTest(tf.test.TestCase):
|
||||
predictions = np.arange(2)
|
||||
labels = np.arange(2)
|
||||
with self.test_session():
|
||||
cm = tf.contrib.metrics.confusion_matrix(
|
||||
predictions, labels, dtype=dtypes.int64)
|
||||
cm = tf.confusion_matrix(
|
||||
labels, predictions, dtype=dtypes.int64)
|
||||
tf_cm = cm.eval()
|
||||
self.assertEqual(tf_cm.dtype, np.int64)
|
||||
|
3360
tensorflow/python/kernel_tests/metrics_test.py
Normal file
3360
tensorflow/python/kernel_tests/metrics_test.py
Normal file
File diff suppressed because it is too large
Load Diff
163
tensorflow/python/ops/confusion_matrix.py
Normal file
163
tensorflow/python/ops/confusion_matrix.py
Normal file
@ -0,0 +1,163 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Confusion matrix related utilities.
|
||||
|
||||
|
||||
@@remove_squeezable_dimensions
|
||||
@@confusion_matrix
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
def remove_squeezable_dimensions(labels, predictions, name=None):
|
||||
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
|
||||
|
||||
This will use static shape if available. Otherwise, it will add graph
|
||||
operations, which could result in a performance hit.
|
||||
|
||||
Args:
|
||||
labels: Label values, a `Tensor` whose dimensions match `predictions`.
|
||||
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
|
||||
name: Name of the op.
|
||||
|
||||
Returns:
|
||||
Tuple of `labels` and `predictions`, possibly with last dim squeezed.
|
||||
"""
|
||||
with ops.name_scope(name, 'remove_squeezable_dimensions',
|
||||
[labels, predictions]):
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
predictions_shape = predictions.get_shape()
|
||||
predictions_rank = predictions_shape.ndims
|
||||
labels_shape = labels.get_shape()
|
||||
labels_rank = labels_shape.ndims
|
||||
if (labels_rank is not None) and (predictions_rank is not None):
|
||||
# Use static rank.
|
||||
rank_diff = predictions_rank - labels_rank
|
||||
if rank_diff == -1:
|
||||
labels = array_ops.squeeze(labels, [-1])
|
||||
elif rank_diff == 1:
|
||||
predictions = array_ops.squeeze(predictions, [-1])
|
||||
return labels, predictions
|
||||
|
||||
# Use dynamic rank.
|
||||
rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
|
||||
if (predictions_rank is None) or (
|
||||
predictions_shape.dims[-1].is_compatible_with(1)):
|
||||
predictions = control_flow_ops.cond(
|
||||
math_ops.equal(1, rank_diff),
|
||||
lambda: array_ops.squeeze(predictions, [-1]),
|
||||
lambda: predictions)
|
||||
if (labels_rank is None) or (
|
||||
labels_shape.dims[-1].is_compatible_with(1)):
|
||||
labels = control_flow_ops.cond(
|
||||
math_ops.equal(-1, rank_diff),
|
||||
lambda: array_ops.squeeze(labels, [-1]),
|
||||
lambda: labels)
|
||||
return labels, predictions
|
||||
|
||||
|
||||
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
|
||||
name=None, weights=None):
|
||||
"""Computes the confusion matrix from predictions and labels.
|
||||
|
||||
Calculate the Confusion Matrix for a pair of prediction and
|
||||
label 1-D int arrays.
|
||||
|
||||
The matrix rows represent the prediction labels and the columns
|
||||
represents the real labels. The confusion matrix is always a 2-D array
|
||||
of shape `[n, n]`, where `n` is the number of valid labels for a given
|
||||
classification task. Both prediction and labels must be 1-D arrays of
|
||||
the same shape in order for this function to work.
|
||||
|
||||
If `num_classes` is None, then `num_classes` will be set to the one plus
|
||||
the maximum value in either predictions or labels.
|
||||
Class labels are expected to start at 0. E.g., if `num_classes` was
|
||||
three, then the possible labels would be `[0, 1, 2]`.
|
||||
|
||||
If `weights` is not `None`, then each prediction contributes its
|
||||
corresponding weight to the total value of the confusion matrix cell.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
|
||||
[[0 0 0 0 0]
|
||||
[0 0 1 0 0]
|
||||
[0 0 1 0 0]
|
||||
[0 0 0 0 0]
|
||||
[0 0 0 0 1]]
|
||||
```
|
||||
|
||||
Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
|
||||
resulting in a 5x5 confusion matrix.
|
||||
|
||||
Args:
|
||||
labels: A 1-D representing the real labels for the classification task.
|
||||
predictions: A 1-D array representing the predictions for a given
|
||||
classification.
|
||||
num_classes: The possible number of labels the classification task can
|
||||
have. If this value is not provided, it will be calculated
|
||||
using both predictions and labels array.
|
||||
dtype: Data type of the confusion matrix.
|
||||
name: Scope name.
|
||||
weights: An optional `Tensor` whose shape matches `predictions`.
|
||||
|
||||
Returns:
|
||||
A k X k matrix representing the confusion matrix, where k is the number of
|
||||
possible labels in the classification task.
|
||||
|
||||
Raises:
|
||||
ValueError: If both predictions and labels are not 1-D vectors and have
|
||||
mismatched shapes, or if `weights` is not `None` and its shape doesn't
|
||||
match `predictions`.
|
||||
"""
|
||||
with ops.name_scope(name, 'confusion_matrix',
|
||||
[predictions, labels, num_classes]) as name:
|
||||
labels, predictions = remove_squeezable_dimensions(
|
||||
ops.convert_to_tensor(labels, name='labels'),
|
||||
ops.convert_to_tensor(
|
||||
predictions, name='predictions'))
|
||||
predictions = math_ops.cast(predictions, dtypes.int64)
|
||||
labels = math_ops.cast(labels, dtypes.int64)
|
||||
|
||||
if num_classes is None:
|
||||
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
|
||||
math_ops.reduce_max(labels)) + 1
|
||||
|
||||
if weights is not None:
|
||||
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
|
||||
weights = math_ops.cast(weights, dtype)
|
||||
|
||||
shape = array_ops.pack([num_classes, num_classes])
|
||||
indices = array_ops.transpose(array_ops.pack([predictions, labels]))
|
||||
values = (array_ops.ones_like(predictions, dtype)
|
||||
if weights is None else weights)
|
||||
cm_sparse = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=values, shape=math_ops.to_int64(shape))
|
||||
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
|
||||
|
||||
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
|
2588
tensorflow/python/ops/metrics.py
Normal file
2588
tensorflow/python/ops/metrics.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -39,6 +39,7 @@ from tensorflow.python.ops.check_ops import *
|
||||
from tensorflow.python.ops.clip_ops import *
|
||||
from tensorflow.python.ops.special_math_ops import *
|
||||
# TODO(vrv): Switch to import * once we're okay with exposing the module.
|
||||
from tensorflow.python.ops.confusion_matrix import confusion_matrix
|
||||
from tensorflow.python.ops.control_flow_ops import Assert
|
||||
from tensorflow.python.ops.control_flow_ops import group
|
||||
from tensorflow.python.ops.control_flow_ops import no_op
|
||||
@ -91,6 +92,7 @@ from tensorflow.python.framework import constant_op as _constant_op
|
||||
from tensorflow.python.ops import array_ops as _array_ops
|
||||
from tensorflow.python.ops import check_ops as _check_ops
|
||||
from tensorflow.python.ops import clip_ops as _clip_ops
|
||||
from tensorflow.python.ops import confusion_matrix as _confusion_matrix
|
||||
from tensorflow.python.ops import control_flow_ops as _control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops as _data_flow_ops
|
||||
from tensorflow.python.ops import functional_ops as _functional_ops
|
||||
@ -244,6 +246,7 @@ _allowed_symbols_misc = [
|
||||
"parse_single_sequence_example",
|
||||
"serialize_many_sparse",
|
||||
"serialize_sparse",
|
||||
"confusion_matrix",
|
||||
]
|
||||
|
||||
_allowed_symbols = (_allowed_symbols_array_ops +
|
||||
@ -262,6 +265,7 @@ remove_undocumented(__name__, _allowed_symbols,
|
||||
_array_ops,
|
||||
_check_ops,
|
||||
_clip_ops,
|
||||
_confusion_matrix,
|
||||
_control_flow_ops,
|
||||
_constant_op,
|
||||
_data_flow_ops,
|
||||
|
Loading…
Reference in New Issue
Block a user