STT-tensorflow/tensorflow/python/kernel_tests/confusion_matrix_test.py
Gaurav Jain f618ab4955 Move away from deprecated asserts
- assertEquals -> assertEqual
- assertRaisesRegexp -> assertRaisesRegex
- assertRegexpMatches -> assertRegex

PiperOrigin-RevId: 319118081
Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
2020-06-30 16:10:22 -07:00

480 lines
18 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for confusion_matrix_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
class ConfusionMatrixTest(test.TestCase):
  """Value, dtype and error checks for `confusion_matrix.confusion_matrix`."""

  @test_util.run_in_graph_and_eager_modes
  def testExample(self):
    """Reproduces the example provided in the pydoc."""
    expected = [
        [0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1],
    ]
    with self.cached_session():
      cm = confusion_matrix.confusion_matrix(
          labels=[1, 2, 4], predictions=[2, 2, 4])
      self.assertAllEqual(expected, self.evaluate(cm))

  def _testConfMatrix(self, labels, predictions, truth, weights=None,
                      num_classes=None):
    """Builds a confusion matrix and compares it against `truth`.

    The matrix is requested with the dtype of `predictions`; both the values
    (within `atol=1e-10`) and the dtype of the evaluated result are asserted.

    Args:
      labels: Array of true class ids.
      predictions: Array of predicted class ids; its dtype is the output dtype.
      truth: Expected confusion matrix.
      weights: Optional weights forwarded to `confusion_matrix`.
      num_classes: Optional class count forwarded to `confusion_matrix`.
    """
    with self.cached_session():
      expected_dtype = predictions.dtype
      cm_op = confusion_matrix.confusion_matrix(
          labels,
          predictions,
          dtype=expected_dtype,
          weights=weights,
          num_classes=num_classes)
      result = cm_op.eval()
      self.assertAllClose(truth, result, atol=1e-10)
      self.assertEqual(result.dtype, expected_dtype)

  def _testBasic(self, dtype):
    """Identical labels and predictions 0..4 expect a 5x5 identity matrix."""
    class_ids = np.arange(5, dtype=dtype)
    self._testConfMatrix(
        labels=class_ids,
        predictions=np.arange(5, dtype=dtype),
        truth=np.eye(5, dtype=dtype))

  @test_util.run_deprecated_v1
  def testInt32Basic(self):
    """Runs the basic check with int32 inputs."""
    self._testBasic(dtype=np.int32)

  @test_util.run_deprecated_v1
  def testInt64Basic(self):
    """Runs the basic check with int64 inputs."""
    self._testBasic(dtype=np.int64)

  def _testConfMatrixOnTensors(self, tf_dtype, np_dtype):
    """Compares the op on tensor inputs with a matrix counted in NumPy.

    Predictions are random normal samples (20 per class) rounded, cast to
    `tf_dtype` and clamped to {0, 1}; labels are 20 zeros followed by 20 ones.

    Args:
      tf_dtype: TensorFlow dtype used for the data, labels and output.
      np_dtype: NumPy dtype the evaluated matrix is expected to have.
    """
    with self.cached_session() as sess:
      mean_neg = array_ops.placeholder(dtype=dtypes.float32)
      mean_pos = array_ops.placeholder(dtype=dtypes.float32)
      stddev = array_ops.placeholder(dtype=dtypes.float32)

      neg_samples = random_ops.random_normal(
          [20], mean=mean_neg, stddev=stddev, dtype=dtypes.float32)
      pos_samples = random_ops.random_normal(
          [20], mean=mean_pos, stddev=stddev, dtype=dtypes.float32)

      samples = array_ops.concat([neg_samples, pos_samples], 0)
      rounded = math_ops.cast(math_ops.round(samples), tf_dtype)
      # Clamp to the two valid class ids.
      predictions = math_ops.minimum(math_ops.maximum(rounded, 0), 1)

      zeros = array_ops.zeros([20], dtype=tf_dtype)
      ones = array_ops.ones([20], dtype=tf_dtype)
      labels = array_ops.concat([zeros, ones], 0)

      cm = confusion_matrix.confusion_matrix(
          labels, predictions, dtype=tf_dtype, num_classes=2)

      feeds = {mean_neg: 0.0, mean_pos: 1.0, stddev: 1.0}
      pred_out, label_out, cm_out = sess.run([predictions, labels, cm], feeds)

      # Count (label, prediction) pairs by hand as the reference matrix.
      truth = np.zeros([2, 2], dtype=np_dtype)
      for label, pred in zip(label_out, pred_out):
        truth[label, pred] += 1

      self.assertEqual(cm_out.dtype, np_dtype)
      self.assertAllClose(cm_out, truth, atol=1e-10)

  @test_util.run_deprecated_v1
  def testOnTensors_int32(self):
    """Runs the tensor-input check with int32."""
    self._testConfMatrixOnTensors(dtypes.int32, np.int32)

  @test_util.run_deprecated_v1
  def testOnTensors_int64(self):
    """Runs the tensor-input check with int64."""
    self._testConfMatrixOnTensors(dtypes.int64, np.int64)

  def _testDifferentLabelsInPredictionAndTarget(self, dtype):
    """Labels and predictions that share no class ids."""
    labels = np.array([4, 5, 6], dtype=dtype)
    predictions = np.array([1, 2, 3], dtype=dtype)
    # Expected 7x7 matrix with a single count at each (label, prediction) pair.
    truth = np.zeros([7, 7], dtype=dtype)
    truth[labels, predictions] = 1
    self._testConfMatrix(labels=labels, predictions=predictions, truth=truth)

  @test_util.run_deprecated_v1
  def testInt32DifferentLabels(self, dtype=np.int32):
    """Runs the disjoint-labels check with int32."""
    self._testDifferentLabelsInPredictionAndTarget(dtype)

  @test_util.run_deprecated_v1
  def testInt64DifferentLabels(self, dtype=np.int64):
    """Runs the disjoint-labels check with int64."""
    self._testDifferentLabelsInPredictionAndTarget(dtype)

  def _testMultipleLabels(self, dtype):
    """Repeated (label, prediction) pairs accumulate in the same cell."""
    labels = np.array([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype)
    predictions = np.array([1, 1, 2, 3, 5, 6, 1, 2, 3, 4], dtype=dtype)
    expected_rows = [
        [0, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 1, 0, 1],
        [0, 0, 1, 0, 0, 0, 0],
        [0, 1, 0, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 0, 0],
    ]
    truth = np.array(expected_rows, dtype=dtype)
    self._testConfMatrix(labels=labels, predictions=predictions, truth=truth)

  @test_util.run_deprecated_v1
  def testInt32MultipleLabels(self, dtype=np.int32):
    """Runs the repeated-pairs check with int32."""
    self._testMultipleLabels(dtype)

  @test_util.run_deprecated_v1
  def testInt64MultipleLabels(self, dtype=np.int64):
    """Runs the repeated-pairs check with int64."""
    self._testMultipleLabels(dtype)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    """Each diagonal cell is expected to hold the weight of its sample."""
    labels = np.arange(5, dtype=np.int32)
    predictions = np.arange(5, dtype=np.int32)
    weights = np.arange(5, dtype=np.int32)
    # Sample i has label i, prediction i and weight i, so the expected matrix
    # is diag(0, 1, 2, 3, 4) as int32.
    truth = np.diag(weights)
    self._testConfMatrix(
        labels=labels, predictions=predictions, weights=weights, truth=truth)

  @test_util.run_deprecated_v1
  def testLabelsTooLarge(self):
    """Labels >= num_classes are expected to raise InvalidArgumentError."""
    labels = np.array([1, 1, 0, 3, 5], dtype=np.int32)
    predictions = np.array([2, 1, 0, 2, 2], dtype=np.int32)
    with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError,
                                             "`labels`.*out of bound"):
      self._testConfMatrix(
          labels=labels, predictions=predictions, num_classes=3, truth=None)

  def testLabelsNegative(self):
    """Negative labels are expected to raise an op error."""
    labels = np.array([1, 1, 0, -1, -1], dtype=np.int32)
    predictions = np.array([2, 1, 0, 2, 2], dtype=np.int32)
    with self.assertRaisesOpError("`labels`.*negative values"):
      self._testConfMatrix(
          labels=labels, predictions=predictions, num_classes=3, truth=None)

  @test_util.run_deprecated_v1
  def testPredictionsTooLarge(self):
    """Predictions >= num_classes are expected to raise InvalidArgumentError."""
    labels = np.array([1, 1, 0, 2, 2], dtype=np.int32)
    predictions = np.array([2, 1, 0, 3, 5], dtype=np.int32)
    with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError,
                                             "`predictions`.*out of bound"):
      self._testConfMatrix(
          labels=labels, predictions=predictions, num_classes=3, truth=None)

  def testPredictionsNegative(self):
    """Negative predictions are expected to raise an op error."""
    labels = np.array([1, 1, 0, 2, 2], dtype=np.int32)
    predictions = np.array([2, 1, 0, -1, -1], dtype=np.int32)
    with self.assertRaisesOpError("`predictions`.*negative values"):
      self._testConfMatrix(
          labels=labels, predictions=predictions, num_classes=3, truth=None)

  @test_util.run_deprecated_v1
  def testInputDifferentSize(self):
    """Inputs of different lengths are expected to raise ValueError."""
    labels = np.asarray([1, 2])
    predictions = np.asarray([1, 2, 3])
    with self.assertRaisesRegex(ValueError, "must be equal"):
      # Positional order kept as in the original call: the 3-element array
      # first, the 2-element array second.
      confusion_matrix.confusion_matrix(predictions, labels)

  def testOutputIsInt32(self):
    """Requesting dtype=int32 yields an int32 result."""
    labels = np.arange(2)
    predictions = np.arange(2)
    with self.cached_session():
      cm = confusion_matrix.confusion_matrix(
          labels, predictions, dtype=dtypes.int32)
      evaluated = self.evaluate(cm)
    self.assertEqual(evaluated.dtype, np.int32)

  def testOutputIsInt64(self):
    """Requesting dtype=int64 yields an int64 result."""
    labels = np.arange(2)
    predictions = np.arange(2)
    with self.cached_session():
      cm = confusion_matrix.confusion_matrix(
          labels, predictions, dtype=dtypes.int64)
      evaluated = self.evaluate(cm)
    self.assertEqual(evaluated.dtype, np.int64)
class RemoveSqueezableDimensionsTest(test.TestCase):
  """Tests for `confusion_matrix.remove_squeezable_dimensions`."""

  def _assertStaticAndDynamicResults(self, label_values, prediction_values,
                                     expected_labels, expected_predictions,
                                     placeholder_dtype, **squeeze_kwargs):
    """Checks the op on concrete values and on fed placeholders.

    The op is built twice: once directly on the given values and once on
    placeholders created without a shape, which are then fed the same values.
    Both pairs of outputs must equal the expected arrays.

    Args:
      label_values: Labels passed directly and fed to the labels placeholder.
      prediction_values: Predictions passed directly and fed to the
        predictions placeholder.
      expected_labels: Expected labels output.
      expected_predictions: Expected predictions output.
      placeholder_dtype: dtype of the two placeholders.
      **squeeze_kwargs: Extra keyword arguments forwarded unchanged to
        `remove_squeezable_dimensions` (e.g. `expected_rank_diff`).
    """
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values, **squeeze_kwargs))

    labels_ph = array_ops.placeholder(dtype=placeholder_dtype)
    predictions_ph = array_ops.placeholder(dtype=placeholder_dtype)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_ph, predictions_ph, **squeeze_kwargs))

    with self.cached_session():
      self.assertAllEqual(expected_labels, self.evaluate(static_labels))
      self.assertAllEqual(expected_predictions,
                          self.evaluate(static_predictions))

      feeds = {
          labels_ph: label_values,
          predictions_ph: prediction_values,
      }
      self.assertAllEqual(expected_labels,
                          dynamic_labels.eval(feed_dict=feeds))
      self.assertAllEqual(expected_predictions,
                          dynamic_predictions.eval(feed_dict=feeds))

  @test_util.run_deprecated_v1
  def testBothScalarShape(self):
    """Scalar labels and predictions are expected back unchanged."""
    label_values = 1.0
    prediction_values = 0.0
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=label_values,
        expected_predictions=prediction_values,
        placeholder_dtype=dtypes.float32)

  @test_util.run_deprecated_v1
  def testSameShape(self):
    """Inputs that are both (2, 3, 1) are expected back unchanged."""
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros_like(label_values)
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=label_values,
        expected_predictions=prediction_values,
        placeholder_dtype=dtypes.int32)

  @test_util.run_deprecated_v1
  def testSameShapeExpectedRankDiff0(self):
    """Same as `testSameShape`, with `expected_rank_diff=0` passed explicitly."""
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros_like(label_values)
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=label_values,
        expected_predictions=prediction_values,
        placeholder_dtype=dtypes.int32,
        expected_rank_diff=0)

  @test_util.run_deprecated_v1
  def testSqueezableLabels(self):
    """(2, 3, 1) labels are expected as (2, 3) next to (2, 3) predictions."""
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros(shape=(2, 3))
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=np.squeeze(label_values, axis=-1),
        expected_predictions=prediction_values,
        placeholder_dtype=dtypes.int32)

  @test_util.run_deprecated_v1
  def testSqueezableLabelsExpectedRankDiffPlus1(self):
    """With `expected_rank_diff=1`, (2, 3, 1) labels are expected as (2, 3)."""
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros(shape=(2, 3, 5))
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=np.squeeze(label_values, axis=-1),
        expected_predictions=prediction_values,
        placeholder_dtype=dtypes.int32,
        expected_rank_diff=1)

  @test_util.run_deprecated_v1
  def testSqueezablePredictions(self):
    """(2, 3, 1) predictions are expected as (2, 3) next to (2, 3) labels."""
    label_values = np.ones(shape=(2, 3))
    prediction_values = np.zeros(shape=(2, 3, 1))
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=label_values,
        expected_predictions=np.squeeze(prediction_values, axis=-1),
        placeholder_dtype=dtypes.int32)

  @test_util.run_deprecated_v1
  def testSqueezablePredictionsExpectedRankDiffMinus1(self):
    """With `expected_rank_diff=-1`, (2, 3, 1) predictions become (2, 3)."""
    label_values = np.ones(shape=(2, 3, 5))
    prediction_values = np.zeros(shape=(2, 3, 1))
    self._assertStaticAndDynamicResults(
        label_values,
        prediction_values,
        expected_labels=label_values,
        expected_predictions=np.squeeze(prediction_values, axis=-1),
        placeholder_dtype=dtypes.int32,
        expected_rank_diff=-1)

  @test_util.run_deprecated_v1
  def testUnsqueezableLabels(self):
    """Labels with last dim 2: only the predictions output is checked."""
    label_values = np.ones(shape=(2, 3, 2))
    prediction_values = np.zeros(shape=(2, 3))

    labels_ph = array_ops.placeholder(dtype=dtypes.int32)
    predictions_ph = array_ops.placeholder(dtype=dtypes.int32)
    unused_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(labels_ph,
                                                      predictions_ph))

    with self.cached_session():
      feeds = {
          labels_ph: label_values,
          predictions_ph: prediction_values,
      }
      self.assertAllEqual(prediction_values,
                          dynamic_predictions.eval(feed_dict=feeds))

  @test_util.run_deprecated_v1
  def testUnsqueezablePredictions(self):
    """Predictions with last dim 2: only the labels output is checked."""
    label_values = np.ones(shape=(2, 3))
    prediction_values = np.zeros(shape=(2, 3, 2))

    labels_ph = array_ops.placeholder(dtype=dtypes.int32)
    predictions_ph = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, unused_predictions = (
        confusion_matrix.remove_squeezable_dimensions(labels_ph,
                                                      predictions_ph))

    with self.cached_session():
      feeds = {
          labels_ph: label_values,
          predictions_ph: prediction_values,
      }
      self.assertAllEqual(label_values,
                          dynamic_labels.eval(feed_dict=feeds))
if __name__ == "__main__":
  # Run the tests in this module via `test.main()` when executed as a script.
  test.main()