From 4d9a24026cbef4115129538456dd3fbbbd477b06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 Jan 2017 12:27:19 -0800 Subject: [PATCH] Transpose confusion matrix output to make it consistent with sklearn. Change: 144356967 --- .../metrics/python/ops/metric_ops_test.py | 6 +- .../kernel_tests/confusion_matrix_test.py | 57 ++++++++++--------- .../python/kernel_tests/metrics_test.py | 6 +- tensorflow/python/ops/confusion_matrix.py | 21 ++++--- 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 35efaf14d1b..3e2e408e6f9 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -4486,7 +4486,7 @@ class StreamingMeanIOUTest(test.TestCase): num_classes) sess.run(variables.local_variables_initializer()) confusion_matrix = update_op.eval() - self.assertAllEqual([[3, 2], [0, 5]], confusion_matrix) + self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix) desired_miou = np.mean([3. / 5., 5. / 7.]) self.assertAlmostEqual(desired_miou, miou.eval()) @@ -4509,7 +4509,7 @@ class StreamingMeanIOUTest(test.TestCase): miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[0, 40], [0, 0]], update_op.eval()) + self.assertAllEqual([[0, 0], [40, 0]], update_op.eval()) self.assertEqual(0., miou.eval()) def testResultsWithSomeMissing(self): @@ -4540,7 +4540,7 @@ class StreamingMeanIOUTest(test.TestCase): miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes, weights=weights) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[2, 2], [0, 4]], update_op.eval()) + self.assertAllEqual([[2, 0], [2, 4]], update_op.eval()) desired_miou = np.mean([2. / 4., 4. / 6.]) self.assertAlmostEqual(desired_miou, miou.eval()) diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index f495a5a1e5c..cf882091488 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -36,14 +36,14 @@ class ConfusionMatrixTest(test.TestCase): with self.test_session(): self.assertAllEqual([ [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1] ], confusion_matrix.confusion_matrix( labels=[1, 2, 4], predictions=[2, 2, 4]).eval()) - def _testConfMatrix(self, predictions, labels, truth, weights=None): + def _testConfMatrix(self, labels, predictions, truth, weights=None): with self.test_session(): dtype = predictions.dtype ans = confusion_matrix.confusion_matrix( @@ -52,8 +52,8 @@ class ConfusionMatrixTest(test.TestCase): self.assertEqual(ans.dtype, dtype) def _testBasic(self, dtype): - predictions = np.arange(5, dtype=dtype) labels = np.arange(5, dtype=dtype) + predictions = np.arange(5, dtype=dtype) truth = np.asarray( [[1, 0, 0, 0, 0], @@ -63,7 +63,7 @@ class ConfusionMatrixTest(test.TestCase): [0, 0, 0, 0, 1]], dtype=dtype) - self._testConfMatrix(predictions=predictions, labels=labels, truth=truth) + self._testConfMatrix(labels=labels, predictions=predictions, truth=truth) def testInt32Basic(self): self._testBasic(dtype=np.int32) @@ -104,32 +104,32 @@ class ConfusionMatrixTest(test.TestCase): except NameError: # In Python 3. range_builder = range for i in range_builder(len(d)): - truth[d[i], l[i]] += 1 + truth[l[i], d[i]] += 1 self.assertEqual(cm_out.dtype, np_dtype) self.assertAllClose(cm_out, truth, atol=1e-10) - def _testOnTensors_int32(self): + def testOnTensors_int32(self): self._testConfMatrixOnTensors(dtypes.int32, np.int32) def testOnTensors_int64(self): self._testConfMatrixOnTensors(dtypes.int64, np.int64) def _testDifferentLabelsInPredictionAndTarget(self, dtype): - predictions = np.asarray([1, 2, 3], dtype=dtype) labels = np.asarray([4, 5, 6], dtype=dtype) + predictions = np.asarray([1, 2, 3], dtype=dtype) truth = np.asarray( [[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0]], + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0]], dtype=dtype) - self._testConfMatrix(predictions=predictions, labels=labels, truth=truth) + self._testConfMatrix(labels=labels, predictions=predictions, truth=truth) def testInt32DifferentLabels(self, dtype=np.int32): self._testDifferentLabelsInPredictionAndTarget(dtype) @@ -138,20 +138,20 @@ class ConfusionMatrixTest(test.TestCase): self._testDifferentLabelsInPredictionAndTarget(dtype) def _testMultipleLabels(self, dtype): - predictions = np.asarray([1, 1, 2, 3, 5, 6, 1, 2, 3, 4], dtype=dtype) labels = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype) + predictions = np.asarray([1, 1, 2, 3, 5, 6, 1, 2, 3, 4], dtype=dtype) truth = np.asarray( [[0, 0, 0, 0, 0, 0, 0], - [0, 2, 0, 1, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 1], - [0, 0, 0, 2, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0], + [0, 2, 0, 0, 1, 0, 1], + [0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0]], + [0, 0, 1, 0, 0, 0, 0]], dtype=dtype) - self._testConfMatrix(predictions=predictions, labels=labels, truth=truth) + self._testConfMatrix(labels=labels, predictions=predictions, truth=truth) def testInt32MultipleLabels(self, dtype=np.int32): self._testMultipleLabels(dtype) @@ -160,8 +160,8 @@ class ConfusionMatrixTest(test.TestCase): self._testMultipleLabels(dtype) def testWeighted(self): - predictions = np.arange(5, dtype=np.int32) labels = np.arange(5, dtype=np.int32) + predictions = np.arange(5, dtype=np.int32) weights = constant_op.constant(np.arange(5, dtype=np.int32)) truth = np.asarray( @@ -173,31 +173,32 @@ class ConfusionMatrixTest(test.TestCase): dtype=np.int32) self._testConfMatrix( - predictions=predictions, labels=labels, weights=weights, truth=truth) + labels=labels, predictions=predictions, weights=weights, truth=truth) - def testInvalidRank(self): - predictions = np.asarray([[1, 2, 3]]) + def testInvalidRank_predictionsTooBig(self): labels = np.asarray([1, 2, 3]) + predictions = np.asarray([[1, 2, 3]]) self.assertRaisesRegexp(ValueError, "an not squeeze dim", confusion_matrix.confusion_matrix, predictions, labels) - predictions = np.asarray([1, 2, 3]) + def testInvalidRank_predictionsTooSmall(self): labels = np.asarray([[1, 2, 3]]) + predictions = np.asarray([1, 2, 3]) self.assertRaisesRegexp(ValueError, "an not squeeze dim", confusion_matrix.confusion_matrix, predictions, labels) def testInputDifferentSize(self): - predictions = np.asarray([1, 2, 3]) labels = np.asarray([1, 2]) + predictions = np.asarray([1, 2, 3]) self.assertRaisesRegexp(ValueError, "must be equal", confusion_matrix.confusion_matrix, predictions, labels) def testOutputIsInt32(self): - predictions = np.arange(2) labels = np.arange(2) + predictions = np.arange(2) with self.test_session(): cm = confusion_matrix.confusion_matrix( labels, predictions, dtype=dtypes.int32) @@ -205,8 +206,8 @@ class ConfusionMatrixTest(test.TestCase): self.assertEqual(tf_cm.dtype, np.int32) def testOutputIsInt64(self): - predictions = np.arange(2) labels = np.arange(2) + predictions = np.arange(2) with self.test_session(): cm = confusion_matrix.confusion_matrix( labels, predictions, dtype=dtypes.int64) diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 0ea14e82bc7..07d805b90ec 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -3296,7 +3296,7 @@ class MeanIOUTest(test.TestCase): miou, update_op = metrics.mean_iou(labels, predictions, num_classes) sess.run(variables.local_variables_initializer()) confusion_matrix = update_op.eval() - self.assertAllEqual([[3, 2], [0, 5]], confusion_matrix) + self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix) desired_miou = np.mean([3. / 5., 5. / 7.]) self.assertAlmostEqual(desired_miou, miou.eval()) @@ -3317,7 +3317,7 @@ class MeanIOUTest(test.TestCase): with self.test_session() as sess: miou, update_op = metrics.mean_iou(labels, predictions, num_classes) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[0, 40], [0, 0]], update_op.eval()) + self.assertAllEqual([[0, 0], [40, 0]], update_op.eval()) self.assertEqual(0., miou.eval()) def testResultsWithSomeMissing(self): @@ -3348,7 +3348,7 @@ class MeanIOUTest(test.TestCase): miou, update_op = metrics.mean_iou( labels, predictions, num_classes, weights=weights) sess.run(variables.local_variables_initializer()) - self.assertAllEqual([[2, 2], [0, 4]], update_op.eval()) + self.assertAllEqual([[2, 0], [2, 4]], update_op.eval()) desired_miou = np.mean([2. / 4., 4. / 6.]) self.assertAlmostEqual(desired_miou, miou.eval()) diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index b8e3791f91b..628853545e9 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -87,11 +87,11 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, Calculate the Confusion Matrix for a pair of prediction and label 1-D int arrays. - The matrix rows represent the prediction labels and the columns - represents the real labels. The confusion matrix is always a 2-D array - of shape `[n, n]`, where `n` is the number of valid labels for a given - classification task. Both prediction and labels must be 1-D arrays of - the same shape in order for this function to work. + The matrix columns represent the prediction labels and the rows represent the + real labels. The confusion matrix is always a 2-D array of shape `[n, n]`, + where `n` is the number of valid labels for a given classification task. Both + prediction and labels must be 1-D arrays of the same shape in order for this + function to work. If `num_classes` is None, then `num_classes` will be set to the one plus the maximum value in either predictions or labels. @@ -106,8 +106,8 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, ```python tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> [[0 0 0 0 0] - [0 0 0 0 0] - [0 1 1 0 0] + [0 0 1 0 0] + [0 0 1 0 0] [0 0 0 0 0] [0 0 0 0 1]] ``` @@ -116,9 +116,8 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, resulting in a 5x5 confusion matrix. Args: - labels: A 1-D representing the real labels for the classification task. - predictions: A 1-D array representing the predictions for a given - classification. + labels: 1-D `Tensor` of real labels for the classification task. + predictions: 1-D `Tensor` of predictions for a given classification. num_classes: The possible number of labels the classification task can have. If this value is not provided, it will be calculated using both predictions and labels array. @@ -153,7 +152,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, weights = math_ops.cast(weights, dtype) shape = array_ops.stack([num_classes, num_classes]) - indices = array_ops.transpose(array_ops.stack([predictions, labels])) + indices = array_ops.transpose(array_ops.stack([labels, predictions])) values = (array_ops.ones_like(predictions, dtype) if weights is None else weights) cm_sparse = sparse_tensor.SparseTensor(