test fixes?
This commit is contained in:
parent
7f389c2ac3
commit
31077fffee
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import confusion_matrix
|
from tensorflow.python.ops import confusion_matrix
|
||||||
@ -188,7 +189,7 @@ class ConfusionMatrixTest(test.TestCase):
|
|||||||
def testLabelsTooLarge(self):
|
def testLabelsTooLarge(self):
|
||||||
labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32)
|
labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32)
|
||||||
predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32)
|
predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32)
|
||||||
with self.assertRaisesOpError("`labels`.*x < y"):
|
with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError, "`labels`.*x < y"):
|
||||||
self._testConfMatrix(
|
self._testConfMatrix(
|
||||||
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
||||||
|
|
||||||
@ -203,7 +204,7 @@ class ConfusionMatrixTest(test.TestCase):
|
|||||||
def testPredictionsTooLarge(self):
|
def testPredictionsTooLarge(self):
|
||||||
labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32)
|
labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32)
|
||||||
predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32)
|
predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32)
|
||||||
with self.assertRaisesOpError("`predictions`.*x < y"):
|
with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError, "`predictions`.*x < y"):
|
||||||
self._testConfMatrix(
|
self._testConfMatrix(
|
||||||
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user